From 69e8e88047639e397bc51a4668fdf022f811cfb8 Mon Sep 17 00:00:00 2001 From: smgr Date: Sat, 14 Mar 2026 22:56:58 -0700 Subject: [PATCH] refactor Plotter to use a common plot base --- trainlib/plotter.py | 579 ++++++++++++++++++++++++--------- trainlib/utils/custom.mplstyle | 77 ++--- trainlib/utils/type.py | 19 +- 3 files changed, 477 insertions(+), 198 deletions(-) diff --git a/trainlib/plotter.py b/trainlib/plotter.py index 2cd45f6..ca872d3 100644 --- a/trainlib/plotter.py +++ b/trainlib/plotter.py @@ -1,19 +1,134 @@ -from typing import Self +""" + def plot_actual_output_residual( + self, + order_residuals: bool = False, + row_size: int | float = 2, + col_size: int | float = 4, + norm_samples: bool = False, + combine_dims: bool = True, + figure_kwargs: SubplotsKwargs | None = None, + subplot_kwargs: dict | None = None, + gof_kwargs: dict | None = None, + ) -> tuple[plt.Figure, AxesArray]: + Note: transform samples in dataloader definitions beforehand if you + want to change data + + Parameters: + row_size: + col_size: + figure_kwargs: + subplot_kwargs: + gof_kwargs: + + ndims = self.data_tuples[0][0].size(-1) + colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] + fig, axes = self._create_subplots( + rows=len(self.dataloaders), + cols=1 if (norm_samples or combine_dims) else ndims, + row_size=row_size, + col_size=col_size, + figure_kwargs=figure_kwargs, + ) + subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs) + gof_kwargs = self._prepare_gof_kwargs(gof_kwargs) + + for axes_row, data_tuple in zip(axes, self.data_tuples, strict=True): + actual, output, loader_label = data_tuple + + if norm_samples: + actual = actual.norm(dim=1, keepdim=True) + output = output.norm(dim=1, keepdim=True) + + for dim in range(ndims): + ax = axes_row[0 if combine_dims else dim] + dim_color = colors[dim % len(colors)] + group_name = "norm" if norm_samples else f"$d_{dim}$" + + dim_actual = actual[:, dim] + dim_output = output[:, dim] + residuals = dim_actual - dim_output + + X, Y = dim_actual, residuals + if order_residuals: + X = range(1, residuals.size(0)+1) + Y = residuals[residuals.argsort()] + + ax.scatter( + X, Y, + color=dim_color, + label=group_name, + **subplot_kwargs, + ) + + # plot goodness of fit line + m, b = self._lstsq_dim(dim_actual, dim_output) + ax.plot( + dim_actual, + m * dim_actual + b, + color=dim_color, + label=f"GoF {group_name}", + **gof_kwargs, + ) + + add_plot_context = ( + norm_samples # dim=0 implicit b/c we break + or not combine_dims # add to every subplot across grid + or combine_dims and dim == ndims-1 # wait for last dim + ) + + # always exec plot logic if not combining, o/w just once + if add_plot_context: + # compare residuals to y=0 + ax.axhline(y=0, c="black", alpha=0.2) + + ax.set_title(f"[{loader_label}] Prediction residuals") + ax.set_xlabel("actual") + ax.set_ylabel("residual") + + # transpose legend layout for more natural view + if norm_samples or not combine_dims: + ax.legend() + else: + handles, labels = self.get_transposed_handles_labels(ax) + ax.legend(handles, labels, ncols=2) + + # break dimension loop if collapsed by norm + if norm_samples: + break + + return fig, axes +""" + from functools import partial -from collections.abc import Callable, Generator +from collections.abc import Callable import numpy as np +import torch import matplotlib.pyplot as plt from torch import Tensor -from numpy.typing import NDArray from torch.utils.data import DataLoader from trainlib.trainer import Trainer from trainlib.estimator import EstimatorKwargs from trainlib.utils.type import AxesArray, SubplotsKwargs +type SubplotFn = Callable[[plt.Axes, int, Tensor, Tensor], None] +type ContextFn = Callable[[plt.Axes, str], None] + class Plotter[B, K: EstimatorKwargs]: + """ + TODOs: + + - best fit lines for plots and residuals (compare to ideal lines in each + case) + - show val options across columns; preview how val is changing across + natural training, and what the best will look (so plot like uniform + intervals broken over the training epochs at 0, 50, 100, 150, ... and + highlight the best one, even if that's not actually the single best + epoch) + """ + def __init__( self, trainer: Trainer[..., K], @@ -35,11 +150,11 @@ class Plotter[B, K: EstimatorKwargs]: self._batch_outputs_fn = partial( self.trainer.get_batch_outputs, - batch_estimator_map=self.batch_estimator_map + batch_estimator_map=batch_estimator_map ) self._batch_metrics_fn = partial( self.trainer.get_batch_metrics, - batch_estimator_map=self.batch_estimator_map + batch_estimator_map=batch_estimator_map ) self._data_tuples = None @@ -83,201 +198,369 @@ class Plotter[B, K: EstimatorKwargs]: data_tuples.append((actual, output, label)) self._data_tuples = data_tuples + return self._data_tuples - def _default_figure_kwargs(self, rows: int, cols: int) -> dict: - return { - "sharex": True, - "figsize": (4*cols, 2*rows), - "constrained_layout": True, - } - - def _default_subplot_kwargs(self) -> dict: - return {} - - def _create_subplots( + def get_transposed_handles_labels( self, - **figure_kwargs: SubplotsKwargs, - ) -> tuple[plt.Figure, AxesArray]: - """ - """ + ax: plt.Axes + ) -> tuple[list, list]: + # transpose legend layout for more natural view + handles, labels = ax.get_legend_handles_labels() + handles = handles[::2] + handles[1::2] + labels = labels[::2] + labels[1::2] - rows, cols = len(self.dataloaders), 1 - figure_kwargs = { - **self._default_figure_kwargs(rows, cols), - **figure_kwargs, - } - fig, axes = plt.subplots( - rows, cols, - squeeze=False, - **figure_kwargs + return handles, labels + + def _lstsq_dim( + self, + dim_actual: Tensor, + dim_output: Tensor + ) -> tuple[Tensor, Tensor]: + A = torch.stack( + [dim_actual, torch.ones_like(dim_actual)], + dim=1 ) + m, b = torch.linalg.lstsq(A, dim_output).solution - return fig, axes + return m, b - def plot_actual_output_dim( + def _prepare_figure_kwargs( self, - figure_kwargs: dict | None = None, - subplot_kwargs: dict | None = None, - ) -> tuple[plt.Figure, AxesArray]: + rows: int, + cols: int, + row_size: int | float = 2, + col_size: int | float = 4, + figure_kwargs: SubplotsKwargs | None = None, + ) -> SubplotsKwargs: """ - Wrapper like this works fine, but it's smelly: we *don't* want @wraps, - do this method doesn't actually have this signature at runtime (it has - the dec wrapper's sig). I think the cleaner thing is to just have - internal methods (_func) like the one below, and then the main method - entry just pass that internal method through to the skeleton """ - figure_kwargs = figure_kwargs or {} + default_figure_kwargs = { + "sharex": True, + "figsize": (col_size*cols, row_size*rows), + } + figure_kwargs = { + **default_figure_kwargs, + **(figure_kwargs or {}), + } + + return figure_kwargs + + def _prepare_subplot_kwargs( + self, + subplot_kwargs: dict | None = None, + ) -> dict: + """ + """ + + default_subplot_kwargs = {} subplot_kwargs = { - **self._default_subplot_kwargs(), + **default_subplot_kwargs, **(subplot_kwargs or {}), } - fig, axes = self._create_subplots(**figure_kwargs) - for ax, data_tuple in zip(axes, self.data_tuples, strict=True): - actual, output, label = data_tuple + return subplot_kwargs + def _prepare_gof_kwargs( + self, + gof_kwargs: dict | None = None, + ) -> dict: + """ + """ + + default_gof_kwargs = { + "alpha": 0.5 + } + gof_kwargs = { + **default_gof_kwargs, + **(gof_kwargs or {}), + } + + return gof_kwargs + + def _create_subplots( + self, + rows: int, + cols: int, + row_size: int | float = 2, + col_size: int | float = 4, + figure_kwargs: SubplotsKwargs | None = None, + ) -> tuple[plt.Figure, AxesArray]: + """ + """ + + figure_kwargs: SubplotsKwargs = self._prepare_figure_kwargs( + rows, + cols, + row_size=row_size, + col_size=col_size, + figure_kwargs=figure_kwargs, + ) + fig, axes = plt.subplots( + rows, + cols, + squeeze=False, + **figure_kwargs, + ) # ty:ignore[no-matching-overload] + + return fig, axes + + def _plot_base( + self, + subplot_fn: SubplotFn, + context_fn: ContextFn, + row_size: int | float = 2, + col_size: int | float = 4, + norm_samples: bool = False, + combine_dims: bool = True, + figure_kwargs: SubplotsKwargs | None = None, + ) -> tuple[plt.Figure, AxesArray]: + """ + Note: transform samples in dataloader definitions beforehand if you + want to change data + + Parameters: + row_size: + col_size: + figure_kwargs: + subplot_kwargs: + gof_kwargs: + """ + + ndims = self.data_tuples[0][0].size(-1) + fig, axes = self._create_subplots( + rows=len(self.dataloaders), + cols=1 if (norm_samples or combine_dims) else ndims, + row_size=row_size, + col_size=col_size, + figure_kwargs=figure_kwargs, + ) + + for axes_row, data_tuple in zip(axes, self.data_tuples, strict=True): + actual, output, loader_label = data_tuple + + if norm_samples: + actual = actual.norm(dim=1, keepdim=True) + output = output.norm(dim=1, keepdim=True) + + for dim in range(ndims): + ax = axes_row[0 if combine_dims else dim] + + subplot_fn(ax, dim, actual, output) + + add_plot_context = ( + norm_samples # dim=0 implicit b/c we break + or not combine_dims # add to every subplot across grid + or combine_dims and dim == ndims-1 # wait for last dim + ) + + # always exec plot logic if not combining, o/w exec just once + if add_plot_context: + context_fn(ax, loader_label) + + # transpose legend layout for more natural view + if norm_samples or not combine_dims: + ax.legend() + else: + handles, labels = self.get_transposed_handles_labels(ax) + ax.legend(handles, labels, ncols=2) + + # break dimension loop if collapsed by norm + if norm_samples: + break + + return fig, axes + + def plot_actual_output( + self, + row_size: int | float = 2, + col_size: int | float = 4, + norm_samples: bool = False, + combine_dims: bool = True, + figure_kwargs: SubplotsKwargs | None = None, + subplot_kwargs: dict | None = None, + gof_kwargs: dict | None = None, + ) -> tuple[plt.Figure, AxesArray]: + """ + Plot residual distribution. + """ + + subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs) + gof_kwargs = self._prepare_gof_kwargs(gof_kwargs) + colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] + + def subplot_fn( + ax: plt.Axes, dim: int, actual: Tensor, output: Tensor + ) -> None: + dim_color = colors[dim % len(colors)] + group_name = "norm" if norm_samples else f"$d_{dim}$" + + dim_actual = actual[:, dim] + dim_output = output[:, dim] + + ax.scatter( + dim_actual, dim_output, + color=dim_color, + label=group_name, + **subplot_kwargs, + ) + + # plot goodness of fit line + m, b = self._lstsq_dim(dim_actual, dim_output) + ax.plot( + dim_actual, m * dim_actual + b, + color=dim_color, + label=f"GoF {group_name}", + **gof_kwargs, + ) + + def context_fn(ax: plt.Axes, loader_label: str) -> None: + # plot perfect prediction reference line, y=x ax.plot( [0, 1], [0, 1], transform=ax.transAxes, c="black", - alpha=0.2 + alpha=0.2, ) - for dim in range(actual.size(-1)): - ax.scatter( - actual[:, dim], - output[:, dim], - label=f"$d_{dim}$", - **subplot_kwargs - ) - - ax.set_title(f"[{label}] True labels vs Predictions (dim-wise)") + ax.set_title( + f"[{loader_label}] True labels vs Predictions" + ) ax.set_xlabel("actual") ax.set_ylabel("output") - ax.legend() - return fig, axes + return self._plot_base( + subplot_fn, context_fn, + row_size=row_size, + col_size=col_size, + norm_samples=norm_samples, + combine_dims=combine_dims, + figure_kwargs=figure_kwargs, + ) - def plot_actual_output_residual_dim( + def plot_actual_output_residual( self, - figure_kwargs: dict | None = None, + row_size: int | float = 2, + col_size: int | float = 4, + order_residuals: bool = False, + norm_samples: bool = False, + combine_dims: bool = True, + figure_kwargs: SubplotsKwargs | None = None, subplot_kwargs: dict | None = None, + gof_kwargs: dict | None = None, ) -> tuple[plt.Figure, AxesArray]: """ + Plot prediction residuals. """ - figure_kwargs = figure_kwargs or {} - subplot_kwargs = { - **self._default_subplot_kwargs(), - **(subplot_kwargs or {}), - } + subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs) + gof_kwargs = self._prepare_gof_kwargs(gof_kwargs) + colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] - fig, axes = self._create_subplots(**figure_kwargs) - for ax, data_tuple in zip(axes, self.data_tuples, strict=True): - actual, output, label = data_tuple + def subplot_fn( + ax: plt.Axes, dim: int, actual: Tensor, output: Tensor + ) -> None: + dim_color = colors[dim % len(colors)] + group_name = "norm" if norm_samples else f"$d_{dim}$" + dim_actual = actual[:, dim] + dim_output = output[:, dim] + + residuals = dim_actual - dim_output + X, Y = dim_actual, residuals + + if order_residuals: + X = range(1, residuals.size(0)+1) + Y = residuals[residuals.argsort()] + + ax.scatter( + X, Y, + color=dim_color, + label=group_name, + **subplot_kwargs, + ) + + # plot goodness of fit line + if not order_residuals: + m, b = self._lstsq_dim(dim_actual, residuals) + ax.plot( + dim_actual, m * dim_actual + b, + color=dim_color, + label=f"GoF {group_name}", + **gof_kwargs, + ) + + def context_fn(ax: plt.Axes, loader_label: str) -> None: # compare residuals to y=0 ax.axhline(y=0, c="black", alpha=0.2) - for dim in range(actual.size(-1)): - ax.scatter( - actual[:, dim], - actual[:, dim] - output[:, dim], - label=f"$d_{dim}$", - **subplot_kwargs - ) - - ax.set_title(f"[{label}] Residuals (dim-wise)") + ax.set_title(f"[{loader_label}] Prediction residuals") ax.set_xlabel("actual") ax.set_ylabel("residual") - ax.legend() - - return fig, axes - - def plot_actual_output_ordered_residual_dim( - self, - figure_kwargs: dict | None = None, - subplot_kwargs: dict | None = None, - ) -> tuple[plt.Figure, AxesArray]: - """ - """ - - figure_kwargs = figure_kwargs or {} - subplot_kwargs = { - **self._default_subplot_kwargs(), - **(subplot_kwargs or {}), - } - - fig, axes = self._create_subplots(**figure_kwargs) - for ax, data_tuple in zip(axes, self.data_tuples, strict=True): - actual, output, label = data_tuple - - # compare residuals to y=0 - ax.axhline(y=0, c="black", alpha=0.2) - - for dim in range(actual.size(-1)): - ax.scatter( - actual[:, dim], - actual[:, dim] - output[:, dim], - label=f"$d_{dim}$", - **subplot_kwargs - ) - - ax.set_title(f"[{label}] Residuals (dim-wise)") - ax.set_xlabel("actual") - ax.set_ylabel("residual") - ax.legend() - - return fig, axes + return self._plot_base( + subplot_fn, context_fn, + row_size=row_size, + col_size=col_size, + norm_samples=norm_samples, + combine_dims=combine_dims, + figure_kwargs=figure_kwargs, + ) + def plot_actual_output_residual_dist( self, - figure_kwargs: dict | None = None, + row_size: int | float = 2, + col_size: int | float = 4, + norm_samples: bool = False, + combine_dims: bool = True, + figure_kwargs: SubplotsKwargs | None = None, subplot_kwargs: dict | None = None, + gof_kwargs: dict | None = None, ) -> tuple[plt.Figure, AxesArray]: """ + Plot residual distribution. """ - figure_kwargs = figure_kwargs or {} - subplot_kwargs = { - **self._default_subplot_kwargs(), - **(subplot_kwargs or {}), - } + subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs) + gof_kwargs = self._prepare_gof_kwargs(gof_kwargs) + colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] - fig, axes = self._create_subplots(**figure_kwargs) - for ax, data_tuple in zip(axes, self.data_tuples, strict=True): - actual, output, label = data_tuple + def subplot_fn( + ax: plt.Axes, dim: int, actual: Tensor, output: Tensor + ) -> None: + dim_color = colors[dim % len(colors)] + group_name = "norm" if norm_samples else f"$d_{dim}$" - N = actual.size(0) - for dim in range(actual.size(-1)): - residuals = actual[:, dim] - output[:, dim] + dim_actual = actual[:, dim] + dim_output = output[:, dim] - _, _, patches = ax.hist( - residuals.abs(), - bins=int(np.sqrt(N)), - density=True, - alpha=0.2, - label=f"$d_{dim}$", - **subplot_kwargs - ) + N = dim_actual.size(0) + residuals = dim_actual - dim_output - # grab color used for hist and mirror in the v-line - color = patches[0].get_facecolor() + _, _, patches = ax.hist( + residuals.abs(), + bins=int(np.sqrt(N)), + density=True, + alpha=0.3, + color=dim_color, + label=group_name, + **subplot_kwargs + ) - mu = residuals.abs().mean().item() - ax.axvline(mu, linestyle=":", c=color, label=f"$\mu_{dim}$") + mu = residuals.abs().mean().item() + ax.axvline(mu, linestyle=":", c=dim_color, label=f"$\\mu_{dim}$") - ax.set_title(f"[{label}] Residual distribution (dim-wise)") + def context_fn(ax: plt.Axes, loader_label: str) -> None: + ax.set_title(f"[{loader_label}] Residual distribution") ax.set_xlabel("actual") - ax.set_ylabel("residual") + ax.set_ylabel("residual (density)") - # transpose legend layout for more natural view - handles, labels = ax.get_legend_handles_labels() - handles = handles[::2] + handles[1::2] - labels = labels[::2] + labels[1::2] - - ax.legend(handles, labels, ncols=2) - - return fig, axes + return self._plot_base( + subplot_fn, context_fn, + row_size=row_size, + col_size=col_size, + norm_samples=norm_samples, + combine_dims=combine_dims, + figure_kwargs=figure_kwargs, + ) diff --git a/trainlib/utils/custom.mplstyle b/trainlib/utils/custom.mplstyle index 7b9558c..09f385c 100644 --- a/trainlib/utils/custom.mplstyle +++ b/trainlib/utils/custom.mplstyle @@ -1,46 +1,49 @@ -text.usetex : False -mathtext.default : regular +text.usetex : False +mathtext.default : regular -font.family : sans-serif -font.sans-serif : DejaVu Sans -font.serif : DejaVu Serif -font.cursive : DejaVu Sans -mathtext.fontset : dejavuserif -font.size : 9 -figure.titlesize : 9 -legend.fontsize : 9 -axes.titlesize : 9 -axes.labelsize : 9 -xtick.labelsize : 9 -ytick.labelsize : 9 +# testing to prevent component overlap/clipping +figure.constrained_layout.use : True -#axes.prop_cycle : cycler('color', ['4f7dd5', 'af7031', '55905e', 'd84739', '888348', 'b75e8b', '2f8f99', '9862cb']) -axes.prop_cycle : cycler('color', ['5e8de4', 'c38141', '67a771', 'e15344', '9e9858', '41a6b0', 'a46fd7', 'c86d9a']) +font.family : sans-serif +font.sans-serif : DejaVu Sans +font.serif : DejaVu Serif +font.cursive : DejaVu Sans +mathtext.fontset : dejavuserif +font.size : 9 +figure.titlesize : 9 +legend.fontsize : 9 +axes.titlesize : 9 +axes.labelsize : 9 +xtick.labelsize : 9 +ytick.labelsize : 9 -image.interpolation : nearest -image.resample : False -image.composite_image : True +# monobiome -d 0.45 -l 22 +axes.prop_cycle : cycler('color', ['5e8de4', 'c38141', '67a771', 'e15344', '9e9858', '41a6b0', 'a46fd7', 'c86d9a']) -axes.spines.left : True -axes.spines.bottom : True -axes.spines.top : False -axes.spines.right : False +image.interpolation : nearest +image.resample : False +image.composite_image : True -axes.linewidth : 1 -xtick.major.width : 1 -xtick.minor.width : 1 -ytick.major.width : 1 -ytick.minor.width : 1 +axes.spines.left : True +axes.spines.bottom : True +axes.spines.top : False +axes.spines.right : False -lines.linewidth : 1 -lines.markersize : 1 +axes.linewidth : 1 +xtick.major.width : 1 +xtick.minor.width : 1 +ytick.major.width : 1 +ytick.minor.width : 1 -savefig.dpi : 300 -savefig.format : svg -savefig.bbox : tight -savefig.pad_inches : 0.1 +lines.linewidth : 1 +lines.markersize : 1 -svg.image_inline : True -svg.fonttype : none +savefig.dpi : 300 +savefig.format : svg +savefig.bbox : tight +savefig.pad_inches : 0.1 -legend.frameon : False +svg.image_inline : True +svg.fonttype : none + +legend.frameon : False diff --git a/trainlib/utils/type.py b/trainlib/utils/type.py index acec06d..db625a3 100644 --- a/trainlib/utils/type.py +++ b/trainlib/utils/type.py @@ -1,5 +1,5 @@ from typing import Any, TypedDict -from collections.abc import Callable, Iterable +from collections.abc import Callable, Iterable, Sequence import numpy as np from torch import Tensor @@ -59,18 +59,11 @@ class SubplotsKwargs(TypedDict, total=False): sharex: bool | str sharey: bool | str squeeze: bool - width_ratios: list[float] - height_ratios: list[float] - subplot_kw: dict - gridspec_kw: dict + width_ratios: Sequence[float] + height_ratios: Sequence[float] + subplot_kw: dict[str, ...] + gridspec_kw: dict[str, ...] figsize: tuple[float, float] dpi: float layout: str - - sharex: bool | Literal["none", "all", "row", "col"] = False, - sharey: bool | Literal["none", "all", "row", "col"] = False, - squeeze: bool = True, - width_ratios: Sequence[float] | None = None, - height_ratios: Sequence[float] | None = None, - subplot_kw: dict[str, Any] | None = None, - gridspec_kw: dict[str, Any] | None = None, + constrained_layout: bool