""" def plot_actual_output_residual( self, order_residuals: bool = False, row_size: int | float = 2, col_size: int | float = 4, norm_samples: bool = False, combine_dims: bool = True, figure_kwargs: SubplotsKwargs | None = None, subplot_kwargs: dict | None = None, gof_kwargs: dict | None = None, ) -> tuple[plt.Figure, AxesArray]: Note: transform samples in dataloader definitions beforehand if you want to change data Parameters: row_size: col_size: figure_kwargs: subplot_kwargs: gof_kwargs: ndims = self.data_tuples[0][0].size(-1) colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] fig, axes = self._create_subplots( rows=len(self.dataloaders), cols=1 if (norm_samples or combine_dims) else ndims, row_size=row_size, col_size=col_size, figure_kwargs=figure_kwargs, ) subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs) gof_kwargs = self._prepare_gof_kwargs(gof_kwargs) for axes_row, data_tuple in zip(axes, self.data_tuples, strict=True): actual, output, loader_label = data_tuple if norm_samples: actual = actual.norm(dim=1, keepdim=True) output = output.norm(dim=1, keepdim=True) for dim in range(ndims): ax = axes_row[0 if combine_dims else dim] dim_color = colors[dim % len(colors)] group_name = "norm" if norm_samples else f"$d_{dim}$" dim_actual = actual[:, dim] dim_output = output[:, dim] residuals = dim_actual - dim_output X, Y = dim_actual, residuals if order_residuals: X = range(1, residuals.size(0)+1) Y = residuals[residuals.argsort()] ax.scatter( X, Y, color=dim_color, label=group_name, **subplot_kwargs, ) # plot goodness of fit line m, b = self._lstsq_dim(dim_actual, dim_output) ax.plot( dim_actual, m * dim_actual + b, color=dim_color, label=f"GoF {group_name}", **gof_kwargs, ) add_plot_context = ( norm_samples # dim=0 implicit b/c we break or not combine_dims # add to every subplot across grid or combine_dims and dim == ndims-1 # wait for last dim ) # always exec plot logic if not combining, o/w just once if add_plot_context: # compare residuals to y=0 ax.axhline(y=0, c="black", alpha=0.2) ax.set_title(f"[{loader_label}] Prediction residuals") ax.set_xlabel("actual") ax.set_ylabel("residual") # transpose legend layout for more natural view if norm_samples or not combine_dims: ax.legend() else: handles, labels = self.get_transposed_handles_labels(ax) ax.legend(handles, labels, ncols=2) # break dimension loop if collapsed by norm if norm_samples: break return fig, axes """ from functools import partial from collections.abc import Callable import numpy as np import torch import matplotlib.pyplot as plt from torch import Tensor from torch.utils.data import DataLoader from trainlib.trainer import Trainer from trainlib.estimator import EstimatorKwargs from trainlib.utils.type import AxesArray, SubplotsKwargs type SubplotFn = Callable[[plt.Axes, int, Tensor, Tensor], None] type ContextFn = Callable[[plt.Axes, str], None] class Plotter[B, K: EstimatorKwargs]: """ TODOs: - best fit lines for plots and residuals (compare to ideal lines in each case) - show val options across columns; preview how val is changing across natural training, and what the best will look (so plot like uniform intervals broken over the training epochs at 0, 50, 100, 150, ... and highlight the best one, even if that's not actually the single best epoch) """ def __init__( self, trainer: Trainer[..., K], dataloaders: list[DataLoader], batch_estimator_map: Callable[[B, Trainer], ...], estimator_to_output_map: Callable[[K], ...], dataloader_labels: list[str] | None = None, ) -> None: self.trainer = trainer self.dataloaders = dataloaders self.dataloader_labels = ( dataloader_labels or list(map(str, range(1, len(dataloaders)+1))) ) self.batch_estimator_map = batch_estimator_map self.estimator_to_output_map = estimator_to_output_map self._outputs: list[list[Tensor]] | None = None self._metrics: list[list[dict[str, float]]] | None = None self._batch_outputs_fn = partial( self.trainer.get_batch_outputs, batch_estimator_map=batch_estimator_map ) self._batch_metrics_fn = partial( self.trainer.get_batch_metrics, batch_estimator_map=batch_estimator_map ) self._data_tuples = None @property def outputs(self) -> list[list[Tensor]]: if self._outputs is None: self._outputs = [ list(map(self._batch_outputs_fn, loader)) for loader in self.dataloaders ] return self._outputs @property def metrics(self) -> list[list[dict[str, float]]]: if self._metrics is None: self._metrics = [ list(map(self._batch_metrics_fn, loader)) for loader in self.dataloaders ] return self._metrics @property def data_tuples(self) -> list[tuple[Tensor, Tensor, str]]: """ Produce data items; to be cached. Zip later with axes """ if self._data_tuples is not None: return self._data_tuples data_tuples = [] for i, loader in enumerate(self.dataloaders): label = self.dataloader_labels[i] batch = next(iter(loader)) est_kwargs = self.batch_estimator_map(batch, self.trainer) actual = self.estimator_to_output_map(est_kwargs) output = self._batch_outputs_fn(batch) data_tuples.append((actual, output, label)) self._data_tuples = data_tuples return self._data_tuples def get_transposed_handles_labels( self, ax: plt.Axes ) -> tuple[list, list]: # transpose legend layout for more natural view handles, labels = ax.get_legend_handles_labels() handles = handles[::2] + handles[1::2] labels = labels[::2] + labels[1::2] return handles, labels def _lstsq_dim( self, dim_actual: Tensor, dim_output: Tensor ) -> tuple[Tensor, Tensor]: A = torch.stack( [dim_actual, torch.ones_like(dim_actual)], dim=1 ) m, b = torch.linalg.lstsq(A, dim_output).solution return m, b def _prepare_figure_kwargs( self, rows: int, cols: int, row_size: int | float = 2, col_size: int | float = 4, figure_kwargs: SubplotsKwargs | None = None, ) -> SubplotsKwargs: """ """ default_figure_kwargs = { "sharex": True, "figsize": (col_size*cols, row_size*rows), } figure_kwargs = { **default_figure_kwargs, **(figure_kwargs or {}), } return figure_kwargs def _prepare_subplot_kwargs( self, subplot_kwargs: dict | None = None, ) -> dict: """ """ default_subplot_kwargs = {} subplot_kwargs = { **default_subplot_kwargs, **(subplot_kwargs or {}), } return subplot_kwargs def _prepare_gof_kwargs( self, gof_kwargs: dict | None = None, ) -> dict: """ """ default_gof_kwargs = { "alpha": 0.5 } gof_kwargs = { **default_gof_kwargs, **(gof_kwargs or {}), } return gof_kwargs def _create_subplots( self, rows: int, cols: int, row_size: int | float = 2, col_size: int | float = 4, figure_kwargs: SubplotsKwargs | None = None, ) -> tuple[plt.Figure, AxesArray]: """ """ figure_kwargs: SubplotsKwargs = self._prepare_figure_kwargs( rows, cols, row_size=row_size, col_size=col_size, figure_kwargs=figure_kwargs, ) fig, axes = plt.subplots( rows, cols, squeeze=False, **figure_kwargs, ) # ty:ignore[no-matching-overload] return fig, axes def _plot_base( self, subplot_fn: SubplotFn, context_fn: ContextFn, row_size: int | float = 2, col_size: int | float = 4, norm_samples: bool = False, combine_dims: bool = True, figure_kwargs: SubplotsKwargs | None = None, ) -> tuple[plt.Figure, AxesArray]: """ Note: transform samples in dataloader definitions beforehand if you want to change data Parameters: row_size: col_size: figure_kwargs: subplot_kwargs: gof_kwargs: """ ndims = self.data_tuples[0][0].size(-1) fig, axes = self._create_subplots( rows=len(self.dataloaders), cols=1 if (norm_samples or combine_dims) else ndims, row_size=row_size, col_size=col_size, figure_kwargs=figure_kwargs, ) for axes_row, data_tuple in zip(axes, self.data_tuples, strict=True): actual, output, loader_label = data_tuple if norm_samples: actual = actual.norm(dim=1, keepdim=True) output = output.norm(dim=1, keepdim=True) for dim in range(ndims): ax = axes_row[0 if combine_dims else dim] subplot_fn(ax, dim, actual, output) add_plot_context = ( norm_samples # dim=0 implicit b/c we break or not combine_dims # add to every subplot across grid or combine_dims and dim == ndims-1 # wait for last dim ) # always exec plot logic if not combining, o/w exec just once if add_plot_context: context_fn(ax, loader_label) # transpose legend layout for more natural view if norm_samples or not combine_dims: ax.legend() else: handles, labels = self.get_transposed_handles_labels(ax) ax.legend(handles, labels, ncols=2) # break dimension loop if collapsed by norm if norm_samples: break return fig, axes def plot_actual_output( self, row_size: int | float = 2, col_size: int | float = 4, norm_samples: bool = False, combine_dims: bool = True, figure_kwargs: SubplotsKwargs | None = None, subplot_kwargs: dict | None = None, gof_kwargs: dict | None = None, ) -> tuple[plt.Figure, AxesArray]: """ Plot residual distribution. """ subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs) gof_kwargs = self._prepare_gof_kwargs(gof_kwargs) colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] def subplot_fn( ax: plt.Axes, dim: int, actual: Tensor, output: Tensor ) -> None: dim_color = colors[dim % len(colors)] group_name = "norm" if norm_samples else f"$d_{dim}$" dim_actual = actual[:, dim] dim_output = output[:, dim] ax.scatter( dim_actual, dim_output, color=dim_color, label=group_name, **subplot_kwargs, ) # plot goodness of fit line m, b = self._lstsq_dim(dim_actual, dim_output) ax.plot( dim_actual, m * dim_actual + b, color=dim_color, label=f"GoF {group_name}", **gof_kwargs, ) def context_fn(ax: plt.Axes, loader_label: str) -> None: # plot perfect prediction reference line, y=x ax.plot( [0, 1], [0, 1], transform=ax.transAxes, c="black", alpha=0.2, ) ax.set_title( f"[{loader_label}] True labels vs Predictions" ) ax.set_xlabel("actual") ax.set_ylabel("output") return self._plot_base( subplot_fn, context_fn, row_size=row_size, col_size=col_size, norm_samples=norm_samples, combine_dims=combine_dims, figure_kwargs=figure_kwargs, ) def plot_actual_output_residual( self, row_size: int | float = 2, col_size: int | float = 4, order_residuals: bool = False, norm_samples: bool = False, combine_dims: bool = True, figure_kwargs: SubplotsKwargs | None = None, subplot_kwargs: dict | None = None, gof_kwargs: dict | None = None, ) -> tuple[plt.Figure, AxesArray]: """ Plot prediction residuals. """ subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs) gof_kwargs = self._prepare_gof_kwargs(gof_kwargs) colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] def subplot_fn( ax: plt.Axes, dim: int, actual: Tensor, output: Tensor ) -> None: dim_color = colors[dim % len(colors)] group_name = "norm" if norm_samples else f"$d_{dim}$" dim_actual = actual[:, dim] dim_output = output[:, dim] residuals = dim_actual - dim_output X, Y = dim_actual, residuals if order_residuals: X = range(1, residuals.size(0)+1) Y = residuals[residuals.argsort()] ax.scatter( X, Y, color=dim_color, label=group_name, **subplot_kwargs, ) # plot goodness of fit line if not order_residuals: m, b = self._lstsq_dim(dim_actual, residuals) ax.plot( dim_actual, m * dim_actual + b, color=dim_color, label=f"GoF {group_name}", **gof_kwargs, ) def context_fn(ax: plt.Axes, loader_label: str) -> None: # compare residuals to y=0 ax.axhline(y=0, c="black", alpha=0.2) ax.set_title(f"[{loader_label}] Prediction residuals") ax.set_xlabel("actual") ax.set_ylabel("residual") return self._plot_base( subplot_fn, context_fn, row_size=row_size, col_size=col_size, norm_samples=norm_samples, combine_dims=combine_dims, figure_kwargs=figure_kwargs, ) def plot_actual_output_residual_dist( self, row_size: int | float = 2, col_size: int | float = 4, norm_samples: bool = False, combine_dims: bool = True, figure_kwargs: SubplotsKwargs | None = None, subplot_kwargs: dict | None = None, gof_kwargs: dict | None = None, ) -> tuple[plt.Figure, AxesArray]: """ Plot residual distribution. """ subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs) gof_kwargs = self._prepare_gof_kwargs(gof_kwargs) colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] def subplot_fn( ax: plt.Axes, dim: int, actual: Tensor, output: Tensor ) -> None: dim_color = colors[dim % len(colors)] group_name = "norm" if norm_samples else f"$d_{dim}$" dim_actual = actual[:, dim] dim_output = output[:, dim] N = dim_actual.size(0) residuals = dim_actual - dim_output _, _, patches = ax.hist( residuals.abs(), bins=int(np.sqrt(N)), density=True, alpha=0.3, color=dim_color, label=group_name, **subplot_kwargs ) mu = residuals.abs().mean().item() ax.axvline(mu, linestyle=":", c=dim_color, label=f"$\\mu_{dim}$") def context_fn(ax: plt.Axes, loader_label: str) -> None: ax.set_title(f"[{loader_label}] Residual distribution") ax.set_xlabel("actual") ax.set_ylabel("residual (density)") return self._plot_base( subplot_fn, context_fn, row_size=row_size, col_size=col_size, norm_samples=norm_samples, combine_dims=combine_dims, figure_kwargs=figure_kwargs, )