diff --git a/trainlib/diagnostic.py b/trainlib/diagnostic.py new file mode 100644 index 0000000..e69de29 diff --git a/trainlib/plotter.py b/trainlib/plotter.py new file mode 100644 index 0000000..2cd45f6 --- /dev/null +++ b/trainlib/plotter.py @@ -0,0 +1,283 @@ +from typing import Self +from functools import partial +from collections.abc import Callable, Generator + +import numpy as np +import matplotlib.pyplot as plt +from torch import Tensor +from numpy.typing import NDArray +from torch.utils.data import DataLoader + +from trainlib.trainer import Trainer +from trainlib.estimator import EstimatorKwargs +from trainlib.utils.type import AxesArray, SubplotsKwargs + + +class Plotter[B, K: EstimatorKwargs]: + def __init__( + self, + trainer: Trainer[..., K], + dataloaders: list[DataLoader], + batch_estimator_map: Callable[[B, Trainer], ...], + estimator_to_output_map: Callable[[K], ...], + dataloader_labels: list[str] | None = None, + ) -> None: + self.trainer = trainer + self.dataloaders = dataloaders + self.dataloader_labels = ( + dataloader_labels or list(map(str, range(1, len(dataloaders)+1))) + ) + self.batch_estimator_map = batch_estimator_map + self.estimator_to_output_map = estimator_to_output_map + + self._outputs: list[list[Tensor]] | None = None + self._metrics: list[list[dict[str, float]]] | None = None + + self._batch_outputs_fn = partial( + self.trainer.get_batch_outputs, + batch_estimator_map=self.batch_estimator_map + ) + self._batch_metrics_fn = partial( + self.trainer.get_batch_metrics, + batch_estimator_map=self.batch_estimator_map + ) + + self._data_tuples = None + + @property + def outputs(self) -> list[list[Tensor]]: + if self._outputs is None: + self._outputs = [ + list(map(self._batch_outputs_fn, loader)) + for loader in self.dataloaders + ] + return self._outputs + + @property + def metrics(self) -> list[list[dict[str, float]]]: + if self._metrics is None: + self._metrics = [ + list(map(self._batch_metrics_fn, loader)) + for loader in self.dataloaders + ] + return self._metrics + + @property + def data_tuples(self) -> list[tuple[Tensor, Tensor, str]]: + """ + Produce data items; to be cached. Zip later with axes + """ + + if self._data_tuples is not None: + return self._data_tuples + + data_tuples = [] + for i, loader in enumerate(self.dataloaders): + label = self.dataloader_labels[i] + + batch = next(iter(loader)) + est_kwargs = self.batch_estimator_map(batch, self.trainer) + actual = self.estimator_to_output_map(est_kwargs) + output = self._batch_outputs_fn(batch) + + data_tuples.append((actual, output, label)) + + self._data_tuples = data_tuples + return self._data_tuples + + def _default_figure_kwargs(self, rows: int, cols: int) -> dict: + return { + "sharex": True, + "figsize": (4*cols, 2*rows), + "constrained_layout": True, + } + + def _default_subplot_kwargs(self) -> dict: + return {} + + def _create_subplots( + self, + **figure_kwargs: SubplotsKwargs, + ) -> tuple[plt.Figure, AxesArray]: + """ + """ + + rows, cols = len(self.dataloaders), 1 + figure_kwargs = { + **self._default_figure_kwargs(rows, cols), + **figure_kwargs, + } + fig, axes = plt.subplots( + rows, cols, + squeeze=False, + **figure_kwargs + ) + + return fig, axes + + def plot_actual_output_dim( + self, + figure_kwargs: dict | None = None, + subplot_kwargs: dict | None = None, + ) -> tuple[plt.Figure, AxesArray]: + """ + Wrapper like this works fine, but it's smelly: we *don't* want @wraps, + do this method doesn't actually have this signature at runtime (it has + the dec wrapper's sig). I think the cleaner thing is to just have + internal methods (_func) like the one below, and then the main method + entry just pass that internal method through to the skeleton + """ + + figure_kwargs = figure_kwargs or {} + subplot_kwargs = { + **self._default_subplot_kwargs(), + **(subplot_kwargs or {}), + } + + fig, axes = self._create_subplots(**figure_kwargs) + for ax, data_tuple in zip(axes, self.data_tuples, strict=True): + actual, output, label = data_tuple + + ax.plot( + [0, 1], [0, 1], + transform=ax.transAxes, + c="black", + alpha=0.2 + ) + + for dim in range(actual.size(-1)): + ax.scatter( + actual[:, dim], + output[:, dim], + label=f"$d_{dim}$", + **subplot_kwargs + ) + + ax.set_title(f"[{label}] True labels vs Predictions (dim-wise)") + ax.set_xlabel("actual") + ax.set_ylabel("output") + ax.legend() + + return fig, axes + + def plot_actual_output_residual_dim( + self, + figure_kwargs: dict | None = None, + subplot_kwargs: dict | None = None, + ) -> tuple[plt.Figure, AxesArray]: + """ + """ + + figure_kwargs = figure_kwargs or {} + subplot_kwargs = { + **self._default_subplot_kwargs(), + **(subplot_kwargs or {}), + } + + fig, axes = self._create_subplots(**figure_kwargs) + for ax, data_tuple in zip(axes, self.data_tuples, strict=True): + actual, output, label = data_tuple + + # compare residuals to y=0 + ax.axhline(y=0, c="black", alpha=0.2) + + for dim in range(actual.size(-1)): + ax.scatter( + actual[:, dim], + actual[:, dim] - output[:, dim], + label=f"$d_{dim}$", + **subplot_kwargs + ) + + ax.set_title(f"[{label}] Residuals (dim-wise)") + ax.set_xlabel("actual") + ax.set_ylabel("residual") + ax.legend() + + return fig, axes + + def plot_actual_output_ordered_residual_dim( + self, + figure_kwargs: dict | None = None, + subplot_kwargs: dict | None = None, + ) -> tuple[plt.Figure, AxesArray]: + """ + """ + + figure_kwargs = figure_kwargs or {} + subplot_kwargs = { + **self._default_subplot_kwargs(), + **(subplot_kwargs or {}), + } + + fig, axes = self._create_subplots(**figure_kwargs) + for ax, data_tuple in zip(axes, self.data_tuples, strict=True): + actual, output, label = data_tuple + + # compare residuals to y=0 + ax.axhline(y=0, c="black", alpha=0.2) + + for dim in range(actual.size(-1)): + ax.scatter( + actual[:, dim], + actual[:, dim] - output[:, dim], + label=f"$d_{dim}$", + **subplot_kwargs + ) + + ax.set_title(f"[{label}] Residuals (dim-wise)") + ax.set_xlabel("actual") + ax.set_ylabel("residual") + ax.legend() + + return fig, axes + + def plot_actual_output_residual_dist( + self, + figure_kwargs: dict | None = None, + subplot_kwargs: dict | None = None, + ) -> tuple[plt.Figure, AxesArray]: + """ + """ + + figure_kwargs = figure_kwargs or {} + subplot_kwargs = { + **self._default_subplot_kwargs(), + **(subplot_kwargs or {}), + } + + fig, axes = self._create_subplots(**figure_kwargs) + for ax, data_tuple in zip(axes, self.data_tuples, strict=True): + actual, output, label = data_tuple + + N = actual.size(0) + for dim in range(actual.size(-1)): + residuals = actual[:, dim] - output[:, dim] + + _, _, patches = ax.hist( + residuals.abs(), + bins=int(np.sqrt(N)), + density=True, + alpha=0.2, + label=f"$d_{dim}$", + **subplot_kwargs + ) + + # grab color used for hist and mirror in the v-line + color = patches[0].get_facecolor() + + mu = residuals.abs().mean().item() + ax.axvline(mu, linestyle=":", c=color, label=f"$\mu_{dim}$") + + ax.set_title(f"[{label}] Residual distribution (dim-wise)") + ax.set_xlabel("actual") + ax.set_ylabel("residual") + + # transpose legend layout for more natural view + handles, labels = ax.get_legend_handles_labels() + handles = handles[::2] + handles[1::2] + labels = labels[::2] + labels[1::2] + + ax.legend(handles, labels, ncols=2) + + return fig, axes diff --git a/trainlib/trainer.py b/trainlib/trainer.py index d24cc84..143fdcc 100644 --- a/trainlib/trainer.py +++ b/trainlib/trainer.py @@ -282,7 +282,8 @@ class Trainer[I, K: EstimatorKwargs]: """ TODO: consider making the dataloader ``collate_fn`` an explicit parameter with a type signature that reflects ``B``, connecting the - ``batch_estimator_map`` somewhere + ``batch_estimator_map`` somewhere. Might also re-type a ``DataLoader`` + in-house to allow a generic around ``B`` Note: this method attempts to implement a general scheme for passing needed items to the estimator's loss function from the dataloader. The @@ -545,6 +546,31 @@ class Trainer[I, K: EstimatorKwargs]: def _add_summary_item(self, name: str, value: float) -> None: self._summary[name].append((value, self._step)) + def get_batch_outputs[B]( + self, + batch: B, + batch_estimator_map: Callable[[B, Self], K], + ) -> Tensor: + self.estimator.eval() + + est_kwargs = batch_estimator_map(batch, self) + output = self.estimator(**est_kwargs)[0] + output = output.detach().cpu() + + return output + + def get_batch_metrics[B]( + self, + batch: B, + batch_estimator_map: Callable[[B, Self], K], + ) -> dict[str, float]: + self.estimator.eval() + + est_kwargs = batch_estimator_map(batch, self) + metrics = self.estimator.metrics(**est_kwargs) + + return metrics + def save_model( self, epoch: int, diff --git a/trainlib/utils/type.py b/trainlib/utils/type.py index c0489cb..acec06d 100644 --- a/trainlib/utils/type.py +++ b/trainlib/utils/type.py @@ -1,11 +1,14 @@ from typing import Any, TypedDict from collections.abc import Callable, Iterable +import numpy as np from torch import Tensor from torch.utils.data.sampler import Sampler from trainlib.dataset import BatchedDataset +# need b/c matplotlib axes are insanely stupid +type AxesArray = np.ndarray[tuple[int, int], np.dtype[np.object_]] class LoaderKwargs(TypedDict, total=False): batch_size: int @@ -50,3 +53,24 @@ class OptimizerKwargs(TypedDict, total=False): capturable: bool differentiable: bool fused: bool | None + + +class SubplotsKwargs(TypedDict, total=False): + sharex: bool | str + sharey: bool | str + squeeze: bool + width_ratios: list[float] + height_ratios: list[float] + subplot_kw: dict + gridspec_kw: dict + figsize: tuple[float, float] + dpi: float + layout: str + + sharex: bool | Literal["none", "all", "row", "col"] = False, + sharey: bool | Literal["none", "all", "row", "col"] = False, + squeeze: bool = True, + width_ratios: Sequence[float] | None = None, + height_ratios: Sequence[float] | None = None, + subplot_kw: dict[str, Any] | None = None, + gridspec_kw: dict[str, Any] | None = None,