move plotting utilities into core Plotter class

This commit is contained in:
2026-03-13 19:58:08 -07:00
parent 95d7bc68ce
commit 30fb6fa107
4 changed files with 334 additions and 1 deletions

0
trainlib/diagnostic.py Normal file
View File

283
trainlib/plotter.py Normal file
View File

@@ -0,0 +1,283 @@
from typing import Self
from functools import partial
from collections.abc import Callable, Generator
import numpy as np
import matplotlib.pyplot as plt
from torch import Tensor
from numpy.typing import NDArray
from torch.utils.data import DataLoader
from trainlib.trainer import Trainer
from trainlib.estimator import EstimatorKwargs
from trainlib.utils.type import AxesArray, SubplotsKwargs
class Plotter[B, K: EstimatorKwargs]:
def __init__(
self,
trainer: Trainer[..., K],
dataloaders: list[DataLoader],
batch_estimator_map: Callable[[B, Trainer], ...],
estimator_to_output_map: Callable[[K], ...],
dataloader_labels: list[str] | None = None,
) -> None:
self.trainer = trainer
self.dataloaders = dataloaders
self.dataloader_labels = (
dataloader_labels or list(map(str, range(1, len(dataloaders)+1)))
)
self.batch_estimator_map = batch_estimator_map
self.estimator_to_output_map = estimator_to_output_map
self._outputs: list[list[Tensor]] | None = None
self._metrics: list[list[dict[str, float]]] | None = None
self._batch_outputs_fn = partial(
self.trainer.get_batch_outputs,
batch_estimator_map=self.batch_estimator_map
)
self._batch_metrics_fn = partial(
self.trainer.get_batch_metrics,
batch_estimator_map=self.batch_estimator_map
)
self._data_tuples = None
@property
def outputs(self) -> list[list[Tensor]]:
if self._outputs is None:
self._outputs = [
list(map(self._batch_outputs_fn, loader))
for loader in self.dataloaders
]
return self._outputs
@property
def metrics(self) -> list[list[dict[str, float]]]:
if self._metrics is None:
self._metrics = [
list(map(self._batch_metrics_fn, loader))
for loader in self.dataloaders
]
return self._metrics
@property
def data_tuples(self) -> list[tuple[Tensor, Tensor, str]]:
"""
Produce data items; to be cached. Zip later with axes
"""
if self._data_tuples is not None:
return self._data_tuples
data_tuples = []
for i, loader in enumerate(self.dataloaders):
label = self.dataloader_labels[i]
batch = next(iter(loader))
est_kwargs = self.batch_estimator_map(batch, self.trainer)
actual = self.estimator_to_output_map(est_kwargs)
output = self._batch_outputs_fn(batch)
data_tuples.append((actual, output, label))
self._data_tuples = data_tuples
return self._data_tuples
def _default_figure_kwargs(self, rows: int, cols: int) -> dict:
return {
"sharex": True,
"figsize": (4*cols, 2*rows),
"constrained_layout": True,
}
def _default_subplot_kwargs(self) -> dict:
return {}
def _create_subplots(
self,
**figure_kwargs: SubplotsKwargs,
) -> tuple[plt.Figure, AxesArray]:
"""
"""
rows, cols = len(self.dataloaders), 1
figure_kwargs = {
**self._default_figure_kwargs(rows, cols),
**figure_kwargs,
}
fig, axes = plt.subplots(
rows, cols,
squeeze=False,
**figure_kwargs
)
return fig, axes
def plot_actual_output_dim(
self,
figure_kwargs: dict | None = None,
subplot_kwargs: dict | None = None,
) -> tuple[plt.Figure, AxesArray]:
"""
Wrapper like this works fine, but it's smelly: we *don't* want @wraps,
do this method doesn't actually have this signature at runtime (it has
the dec wrapper's sig). I think the cleaner thing is to just have
internal methods (_func) like the one below, and then the main method
entry just pass that internal method through to the skeleton
"""
figure_kwargs = figure_kwargs or {}
subplot_kwargs = {
**self._default_subplot_kwargs(),
**(subplot_kwargs or {}),
}
fig, axes = self._create_subplots(**figure_kwargs)
for ax, data_tuple in zip(axes, self.data_tuples, strict=True):
actual, output, label = data_tuple
ax.plot(
[0, 1], [0, 1],
transform=ax.transAxes,
c="black",
alpha=0.2
)
for dim in range(actual.size(-1)):
ax.scatter(
actual[:, dim],
output[:, dim],
label=f"$d_{dim}$",
**subplot_kwargs
)
ax.set_title(f"[{label}] True labels vs Predictions (dim-wise)")
ax.set_xlabel("actual")
ax.set_ylabel("output")
ax.legend()
return fig, axes
def plot_actual_output_residual_dim(
self,
figure_kwargs: dict | None = None,
subplot_kwargs: dict | None = None,
) -> tuple[plt.Figure, AxesArray]:
"""
"""
figure_kwargs = figure_kwargs or {}
subplot_kwargs = {
**self._default_subplot_kwargs(),
**(subplot_kwargs or {}),
}
fig, axes = self._create_subplots(**figure_kwargs)
for ax, data_tuple in zip(axes, self.data_tuples, strict=True):
actual, output, label = data_tuple
# compare residuals to y=0
ax.axhline(y=0, c="black", alpha=0.2)
for dim in range(actual.size(-1)):
ax.scatter(
actual[:, dim],
actual[:, dim] - output[:, dim],
label=f"$d_{dim}$",
**subplot_kwargs
)
ax.set_title(f"[{label}] Residuals (dim-wise)")
ax.set_xlabel("actual")
ax.set_ylabel("residual")
ax.legend()
return fig, axes
def plot_actual_output_ordered_residual_dim(
self,
figure_kwargs: dict | None = None,
subplot_kwargs: dict | None = None,
) -> tuple[plt.Figure, AxesArray]:
"""
"""
figure_kwargs = figure_kwargs or {}
subplot_kwargs = {
**self._default_subplot_kwargs(),
**(subplot_kwargs or {}),
}
fig, axes = self._create_subplots(**figure_kwargs)
for ax, data_tuple in zip(axes, self.data_tuples, strict=True):
actual, output, label = data_tuple
# compare residuals to y=0
ax.axhline(y=0, c="black", alpha=0.2)
for dim in range(actual.size(-1)):
ax.scatter(
actual[:, dim],
actual[:, dim] - output[:, dim],
label=f"$d_{dim}$",
**subplot_kwargs
)
ax.set_title(f"[{label}] Residuals (dim-wise)")
ax.set_xlabel("actual")
ax.set_ylabel("residual")
ax.legend()
return fig, axes
def plot_actual_output_residual_dist(
self,
figure_kwargs: dict | None = None,
subplot_kwargs: dict | None = None,
) -> tuple[plt.Figure, AxesArray]:
"""
"""
figure_kwargs = figure_kwargs or {}
subplot_kwargs = {
**self._default_subplot_kwargs(),
**(subplot_kwargs or {}),
}
fig, axes = self._create_subplots(**figure_kwargs)
for ax, data_tuple in zip(axes, self.data_tuples, strict=True):
actual, output, label = data_tuple
N = actual.size(0)
for dim in range(actual.size(-1)):
residuals = actual[:, dim] - output[:, dim]
_, _, patches = ax.hist(
residuals.abs(),
bins=int(np.sqrt(N)),
density=True,
alpha=0.2,
label=f"$d_{dim}$",
**subplot_kwargs
)
# grab color used for hist and mirror in the v-line
color = patches[0].get_facecolor()
mu = residuals.abs().mean().item()
ax.axvline(mu, linestyle=":", c=color, label=f"$\mu_{dim}$")
ax.set_title(f"[{label}] Residual distribution (dim-wise)")
ax.set_xlabel("actual")
ax.set_ylabel("residual")
# transpose legend layout for more natural view
handles, labels = ax.get_legend_handles_labels()
handles = handles[::2] + handles[1::2]
labels = labels[::2] + labels[1::2]
ax.legend(handles, labels, ncols=2)
return fig, axes

View File

@@ -282,7 +282,8 @@ class Trainer[I, K: EstimatorKwargs]:
"""
TODO: consider making the dataloader ``collate_fn`` an explicit
parameter with a type signature that reflects ``B``, connecting the
``batch_estimator_map`` somewhere
``batch_estimator_map`` somewhere. Might also re-type a ``DataLoader``
in-house to allow a generic around ``B``
Note: this method attempts to implement a general scheme for passing
needed items to the estimator's loss function from the dataloader. The
@@ -545,6 +546,31 @@ class Trainer[I, K: EstimatorKwargs]:
def _add_summary_item(self, name: str, value: float) -> None:
self._summary[name].append((value, self._step))
def get_batch_outputs[B](
self,
batch: B,
batch_estimator_map: Callable[[B, Self], K],
) -> Tensor:
self.estimator.eval()
est_kwargs = batch_estimator_map(batch, self)
output = self.estimator(**est_kwargs)[0]
output = output.detach().cpu()
return output
def get_batch_metrics[B](
self,
batch: B,
batch_estimator_map: Callable[[B, Self], K],
) -> dict[str, float]:
self.estimator.eval()
est_kwargs = batch_estimator_map(batch, self)
metrics = self.estimator.metrics(**est_kwargs)
return metrics
def save_model(
self,
epoch: int,

View File

@@ -1,11 +1,14 @@
from typing import Any, TypedDict
from collections.abc import Callable, Iterable
import numpy as np
from torch import Tensor
from torch.utils.data.sampler import Sampler
from trainlib.dataset import BatchedDataset
# need b/c matplotlib axes are insanely stupid
type AxesArray = np.ndarray[tuple[int, int], np.dtype[np.object_]]
class LoaderKwargs(TypedDict, total=False):
batch_size: int
@@ -50,3 +53,24 @@ class OptimizerKwargs(TypedDict, total=False):
capturable: bool
differentiable: bool
fused: bool | None
class SubplotsKwargs(TypedDict, total=False):
sharex: bool | str
sharey: bool | str
squeeze: bool
width_ratios: list[float]
height_ratios: list[float]
subplot_kw: dict
gridspec_kw: dict
figsize: tuple[float, float]
dpi: float
layout: str
sharex: bool | Literal["none", "all", "row", "col"] = False,
sharey: bool | Literal["none", "all", "row", "col"] = False,
squeeze: bool = True,
width_ratios: Sequence[float] | None = None,
height_ratios: Sequence[float] | None = None,
subplot_kw: dict[str, Any] | None = None,
gridspec_kw: dict[str, Any] | None = None,