# trainlib/trainlib/plotter.py
from functools import partial
from collections.abc import Callable
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch import Tensor
from torch.utils.data import DataLoader
from trainlib.trainer import Trainer
from trainlib.estimator import EstimatorKwargs
from trainlib.utils.type import AxesArray, SubplotsKwargs
type SubplotFn = Callable[[plt.Axes, int, Tensor, Tensor], None]
type ContextFn = Callable[[plt.Axes, str], None]
class Plotter[B, K: EstimatorKwargs]:
"""
TODOs:
- best fit lines for plots and residuals (compare to ideal lines in each
case)
- show val options across columns; preview how val is changing across
natural training, and what the best will look (so plot like uniform
intervals broken over the training epochs at 0, 50, 100, 150, ... and
highlight the best one, even if that's not actually the single best
epoch)
"""
def __init__(
self,
trainer: Trainer[..., K],
dataloaders: list[DataLoader],
batch_estimator_map: Callable[[B, Trainer], ...],
estimator_to_output_map: Callable[[K], ...],
dataloader_labels: list[str] | None = None,
) -> None:
self.trainer = trainer
self.dataloaders = dataloaders
self.dataloader_labels = (
dataloader_labels or list(map(str, range(1, len(dataloaders)+1)))
)
self.batch_estimator_map = batch_estimator_map
self.estimator_to_output_map = estimator_to_output_map
self._outputs: list[list[Tensor]] | None = None
self._metrics: list[list[dict[str, float]]] | None = None
self._batch_outputs_fn = partial(
self.trainer.get_batch_outputs,
batch_estimator_map=batch_estimator_map
)
self._batch_metrics_fn = partial(
self.trainer.get_batch_metrics,
batch_estimator_map=batch_estimator_map
)
self._data_tuples = None
@property
def outputs(self) -> list[list[Tensor]]:
if self._outputs is None:
self._outputs = [
list(map(self._batch_outputs_fn, loader))
for loader in self.dataloaders
]
return self._outputs
@property
def metrics(self) -> list[list[dict[str, float]]]:
if self._metrics is None:
self._metrics = [
list(map(self._batch_metrics_fn, loader))
for loader in self.dataloaders
]
return self._metrics
@property
def data_tuples(self) -> list[tuple[Tensor, Tensor, str]]:
"""
Produce data items; to be cached. Zip later with axes
"""
if self._data_tuples is not None:
return self._data_tuples
data_tuples = []
for i, loader in enumerate(self.dataloaders):
label = self.dataloader_labels[i]
batch = next(iter(loader))
est_kwargs = self.batch_estimator_map(batch, self.trainer)
actual = self.estimator_to_output_map(est_kwargs)
output = self._batch_outputs_fn(batch)
data_tuples.append((actual, output, label))
self._data_tuples = data_tuples
return self._data_tuples
def get_transposed_handles_labels(
self,
ax: plt.Axes
) -> tuple[list, list]:
# transpose legend layout for more natural view
handles, labels = ax.get_legend_handles_labels()
handles = handles[::2] + handles[1::2]
labels = labels[::2] + labels[1::2]
return handles, labels
def _lstsq_dim(
self,
dim_actual: Tensor,
dim_output: Tensor
) -> tuple[Tensor, Tensor]:
A = torch.stack(
[dim_actual, torch.ones_like(dim_actual)],
dim=1
)
m, b = torch.linalg.lstsq(A, dim_output).solution
return m, b
def _prepare_figure_kwargs(
self,
rows: int,
cols: int,
row_size: int | float = 2,
col_size: int | float = 4,
figure_kwargs: SubplotsKwargs | None = None,
) -> SubplotsKwargs:
"""
"""
default_figure_kwargs = {
"sharex": True,
"figsize": (col_size*cols, row_size*rows),
}
figure_kwargs = {
**default_figure_kwargs,
**(figure_kwargs or {}),
}
return figure_kwargs
def _prepare_subplot_kwargs(
self,
subplot_kwargs: dict | None = None,
) -> dict:
"""
"""
default_subplot_kwargs = {}
subplot_kwargs = {
**default_subplot_kwargs,
**(subplot_kwargs or {}),
}
return subplot_kwargs
def _prepare_gof_kwargs(
self,
gof_kwargs: dict | None = None,
) -> dict:
"""
"""
default_gof_kwargs = {
"alpha": 0.5
}
gof_kwargs = {
**default_gof_kwargs,
**(gof_kwargs or {}),
}
return gof_kwargs
def _create_subplots(
self,
rows: int,
cols: int,
row_size: int | float = 2,
col_size: int | float = 4,
figure_kwargs: SubplotsKwargs | None = None,
) -> tuple[plt.Figure, AxesArray]:
"""
"""
figure_kwargs: SubplotsKwargs = self._prepare_figure_kwargs(
rows,
cols,
row_size=row_size,
col_size=col_size,
figure_kwargs=figure_kwargs,
)
fig, axes = plt.subplots(
rows,
cols,
squeeze=False,
**figure_kwargs,
) # ty:ignore[no-matching-overload]
return fig, axes
def _plot_base(
self,
subplot_fn: SubplotFn,
context_fn: ContextFn,
row_size: int | float = 2,
col_size: int | float = 4,
norm_samples: bool = False,
combine_dims: bool = True,
figure_kwargs: SubplotsKwargs | None = None,
) -> tuple[plt.Figure, AxesArray]:
"""
Note: transform samples in dataloader definitions beforehand if you
want to change data
Parameters:
row_size:
col_size:
figure_kwargs:
subplot_kwargs:
gof_kwargs:
"""
ndims = self.data_tuples[0][0].size(-1)
fig, axes = self._create_subplots(
rows=len(self.dataloaders),
cols=1 if (norm_samples or combine_dims) else ndims,
row_size=row_size,
col_size=col_size,
figure_kwargs=figure_kwargs,
)
for axes_row, data_tuple in zip(axes, self.data_tuples, strict=True):
actual, output, loader_label = data_tuple
if norm_samples:
actual = actual.norm(dim=1, keepdim=True)
output = output.norm(dim=1, keepdim=True)
for dim in range(ndims):
ax = axes_row[0 if combine_dims else dim]
subplot_fn(ax, dim, actual, output)
add_plot_context = (
norm_samples # dim=0 implicit b/c we break
or not combine_dims # add to every subplot across grid
or combine_dims and dim == ndims-1 # wait for last dim
)
# always exec plot logic if not combining, o/w exec just once
if add_plot_context:
context_fn(ax, loader_label)
# transpose legend layout for more natural view
if norm_samples or not combine_dims:
ax.legend()
else:
handles, labels = self.get_transposed_handles_labels(ax)
ax.legend(handles, labels, ncols=2)
# break dimension loop if collapsed by norm
if norm_samples:
break
return fig, axes
def plot_actual_output(
self,
row_size: int | float = 2,
col_size: int | float = 4,
norm_samples: bool = False,
combine_dims: bool = True,
figure_kwargs: SubplotsKwargs | None = None,
subplot_kwargs: dict | None = None,
gof_kwargs: dict | None = None,
) -> tuple[plt.Figure, AxesArray]:
"""
Plot residual distribution.
"""
subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs)
gof_kwargs = self._prepare_gof_kwargs(gof_kwargs)
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
def subplot_fn(
ax: plt.Axes, dim: int, actual: Tensor, output: Tensor
) -> None:
dim_color = colors[dim % len(colors)]
group_name = "norm" if norm_samples else f"$d_{dim}$"
dim_actual = actual[:, dim]
dim_output = output[:, dim]
ax.scatter(
dim_actual, dim_output,
color=dim_color,
label=group_name,
**subplot_kwargs,
)
# plot goodness of fit line
m, b = self._lstsq_dim(dim_actual, dim_output)
ax.plot(
dim_actual, m * dim_actual + b,
color=dim_color,
label=f"GoF {group_name}",
**gof_kwargs,
)
def context_fn(ax: plt.Axes, loader_label: str) -> None:
# plot perfect prediction reference line, y=x
ax.plot(
[0, 1], [0, 1],
transform=ax.transAxes,
c="black",
alpha=0.2,
)
ax.set_title(
f"[{loader_label}] True labels vs Predictions"
)
ax.set_xlabel("actual")
ax.set_ylabel("output")
return self._plot_base(
subplot_fn, context_fn,
row_size=row_size,
col_size=col_size,
norm_samples=norm_samples,
combine_dims=combine_dims,
figure_kwargs=figure_kwargs,
)
def plot_actual_output_residual(
self,
row_size: int | float = 2,
col_size: int | float = 4,
order_residuals: bool = False,
norm_samples: bool = False,
combine_dims: bool = True,
figure_kwargs: SubplotsKwargs | None = None,
subplot_kwargs: dict | None = None,
gof_kwargs: dict | None = None,
) -> tuple[plt.Figure, AxesArray]:
"""
Plot prediction residuals.
"""
subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs)
gof_kwargs = self._prepare_gof_kwargs(gof_kwargs)
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
def subplot_fn(
ax: plt.Axes, dim: int, actual: Tensor, output: Tensor
) -> None:
dim_color = colors[dim % len(colors)]
group_name = "norm" if norm_samples else f"$d_{dim}$"
dim_actual = actual[:, dim]
dim_output = output[:, dim]
residuals = dim_actual - dim_output
X, Y = dim_actual, residuals
if order_residuals:
X = range(1, residuals.size(0)+1)
Y = residuals[residuals.argsort()]
ax.scatter(
X, Y,
color=dim_color,
label=group_name,
**subplot_kwargs,
)
# plot goodness of fit line
if not order_residuals:
m, b = self._lstsq_dim(dim_actual, residuals)
ax.plot(
dim_actual, m * dim_actual + b,
color=dim_color,
label=f"GoF {group_name}",
**gof_kwargs,
)
def context_fn(ax: plt.Axes, loader_label: str) -> None:
# compare residuals to y=0
ax.axhline(y=0, c="black", alpha=0.2)
ax.set_title(f"[{loader_label}] Prediction residuals")
ax.set_xlabel("actual")
ax.set_ylabel("residual")
return self._plot_base(
subplot_fn, context_fn,
row_size=row_size,
col_size=col_size,
norm_samples=norm_samples,
combine_dims=combine_dims,
figure_kwargs=figure_kwargs,
)
def plot_actual_output_residual_dist(
self,
row_size: int | float = 2,
col_size: int | float = 4,
norm_samples: bool = False,
combine_dims: bool = True,
figure_kwargs: SubplotsKwargs | None = None,
subplot_kwargs: dict | None = None,
gof_kwargs: dict | None = None,
) -> tuple[plt.Figure, AxesArray]:
"""
Plot residual distribution.
"""
subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs)
gof_kwargs = self._prepare_gof_kwargs(gof_kwargs)
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
def subplot_fn(
ax: plt.Axes, dim: int, actual: Tensor, output: Tensor
) -> None:
dim_color = colors[dim % len(colors)]
group_name = "norm" if norm_samples else f"$d_{dim}$"
dim_actual = actual[:, dim]
dim_output = output[:, dim]
N = dim_actual.size(0)
residuals = dim_actual - dim_output
_, _, patches = ax.hist(
residuals.abs(),
bins=int(np.sqrt(N)),
density=True,
alpha=0.3,
color=dim_color,
label=group_name,
**subplot_kwargs
)
mu = residuals.abs().mean().item()
ax.axvline(mu, linestyle=":", c=dim_color, label=f"$\\mu_{dim}$")
def context_fn(ax: plt.Axes, loader_label: str) -> None:
ax.set_title(f"[{loader_label}] Residual distribution")
ax.set_xlabel("actual")
ax.set_ylabel("residual (density)")
return self._plot_base(
subplot_fn, context_fn,
row_size=row_size,
col_size=col_size,
norm_samples=norm_samples,
combine_dims=combine_dims,
figure_kwargs=figure_kwargs,
)