465 lines
14 KiB
Python
465 lines
14 KiB
Python
from functools import partial
|
|
from collections.abc import Callable
|
|
|
|
import numpy as np
|
|
import torch
|
|
import matplotlib.pyplot as plt
|
|
from torch import Tensor
|
|
from torch.utils.data import DataLoader
|
|
|
|
from trainlib.trainer import Trainer
|
|
from trainlib.estimator import EstimatorKwargs
|
|
from trainlib.utils.type import AxesArray, SubplotsKwargs
|
|
|
|
# Callback drawing one output dimension on an axes: (ax, dim index, actual, output).
type SubplotFn = Callable[[plt.Axes, int, Tensor, Tensor], None]
# Callback adding per-axes decoration (titles, reference lines): (ax, dataloader label).
type ContextFn = Callable[[plt.Axes, str], None]
|
|
|
|
|
|
class Plotter[B, K: EstimatorKwargs]:
    """
    Diagnostic plots for a trained model across one or more dataloaders.

    Produces one subplot row per dataloader: actual-vs-output scatters,
    prediction residuals, and residual distributions, with per-dimension
    least-squares goodness-of-fit lines where applicable. Batch outputs,
    metrics, and (actual, output, label) triples are computed lazily and
    cached on first access.

    TODOs:

    - best fit lines for plots and residuals (compare to ideal lines in each
      case)
    - show val options across columns; preview how val is changing across
      natural training, and what the best will look (so plot like uniform
      intervals broken over the training epochs at 0, 50, 100, 150, ... and
      highlight the best one, even if that's not actually the single best
      epoch)
    """
|
|
|
|
def __init__(
|
|
self,
|
|
trainer: Trainer[..., K],
|
|
dataloaders: list[DataLoader],
|
|
batch_estimator_map: Callable[[B, Trainer], ...],
|
|
estimator_to_output_map: Callable[[K], ...],
|
|
dataloader_labels: list[str] | None = None,
|
|
) -> None:
|
|
self.trainer = trainer
|
|
self.dataloaders = dataloaders
|
|
self.dataloader_labels = (
|
|
dataloader_labels or list(map(str, range(1, len(dataloaders)+1)))
|
|
)
|
|
self.batch_estimator_map = batch_estimator_map
|
|
self.estimator_to_output_map = estimator_to_output_map
|
|
|
|
self._outputs: list[list[Tensor]] | None = None
|
|
self._metrics: list[list[dict[str, float]]] | None = None
|
|
|
|
self._batch_outputs_fn = partial(
|
|
self.trainer.get_batch_outputs,
|
|
batch_estimator_map=batch_estimator_map
|
|
)
|
|
self._batch_metrics_fn = partial(
|
|
self.trainer.get_batch_metrics,
|
|
batch_estimator_map=batch_estimator_map
|
|
)
|
|
|
|
self._data_tuples = None
|
|
|
|
@property
|
|
def outputs(self) -> list[list[Tensor]]:
|
|
if self._outputs is None:
|
|
self._outputs = [
|
|
list(map(self._batch_outputs_fn, loader))
|
|
for loader in self.dataloaders
|
|
]
|
|
return self._outputs
|
|
|
|
@property
|
|
def metrics(self) -> list[list[dict[str, float]]]:
|
|
if self._metrics is None:
|
|
self._metrics = [
|
|
list(map(self._batch_metrics_fn, loader))
|
|
for loader in self.dataloaders
|
|
]
|
|
return self._metrics
|
|
|
|
@property
|
|
def data_tuples(self) -> list[tuple[Tensor, Tensor, str]]:
|
|
"""
|
|
Produce data items; to be cached. Zip later with axes
|
|
"""
|
|
|
|
if self._data_tuples is not None:
|
|
return self._data_tuples
|
|
|
|
data_tuples = []
|
|
for i, loader in enumerate(self.dataloaders):
|
|
label = self.dataloader_labels[i]
|
|
|
|
batch = next(iter(loader))
|
|
est_kwargs = self.batch_estimator_map(batch, self.trainer)
|
|
actual = self.estimator_to_output_map(est_kwargs)
|
|
output = self._batch_outputs_fn(batch)
|
|
|
|
data_tuples.append((actual, output, label))
|
|
|
|
self._data_tuples = data_tuples
|
|
|
|
return self._data_tuples
|
|
|
|
def get_transposed_handles_labels(
|
|
self,
|
|
ax: plt.Axes
|
|
) -> tuple[list, list]:
|
|
# transpose legend layout for more natural view
|
|
handles, labels = ax.get_legend_handles_labels()
|
|
handles = handles[::2] + handles[1::2]
|
|
labels = labels[::2] + labels[1::2]
|
|
|
|
return handles, labels
|
|
|
|
def _lstsq_dim(
|
|
self,
|
|
dim_actual: Tensor,
|
|
dim_output: Tensor
|
|
) -> tuple[Tensor, Tensor]:
|
|
A = torch.stack(
|
|
[dim_actual, torch.ones_like(dim_actual)],
|
|
dim=1
|
|
)
|
|
m, b = torch.linalg.lstsq(A, dim_output).solution
|
|
|
|
return m, b
|
|
|
|
def _prepare_figure_kwargs(
|
|
self,
|
|
rows: int,
|
|
cols: int,
|
|
row_size: int | float = 2,
|
|
col_size: int | float = 4,
|
|
figure_kwargs: SubplotsKwargs | None = None,
|
|
) -> SubplotsKwargs:
|
|
"""
|
|
"""
|
|
|
|
default_figure_kwargs = {
|
|
"sharex": True,
|
|
"figsize": (col_size*cols, row_size*rows),
|
|
}
|
|
figure_kwargs = {
|
|
**default_figure_kwargs,
|
|
**(figure_kwargs or {}),
|
|
}
|
|
|
|
return figure_kwargs
|
|
|
|
def _prepare_subplot_kwargs(
|
|
self,
|
|
subplot_kwargs: dict | None = None,
|
|
) -> dict:
|
|
"""
|
|
"""
|
|
|
|
default_subplot_kwargs = {}
|
|
subplot_kwargs = {
|
|
**default_subplot_kwargs,
|
|
**(subplot_kwargs or {}),
|
|
}
|
|
|
|
return subplot_kwargs
|
|
|
|
def _prepare_gof_kwargs(
|
|
self,
|
|
gof_kwargs: dict | None = None,
|
|
) -> dict:
|
|
"""
|
|
"""
|
|
|
|
default_gof_kwargs = {
|
|
"alpha": 0.5
|
|
}
|
|
gof_kwargs = {
|
|
**default_gof_kwargs,
|
|
**(gof_kwargs or {}),
|
|
}
|
|
|
|
return gof_kwargs
|
|
|
|
def _create_subplots(
|
|
self,
|
|
rows: int,
|
|
cols: int,
|
|
row_size: int | float = 2,
|
|
col_size: int | float = 4,
|
|
figure_kwargs: SubplotsKwargs | None = None,
|
|
) -> tuple[plt.Figure, AxesArray]:
|
|
"""
|
|
"""
|
|
|
|
figure_kwargs: SubplotsKwargs = self._prepare_figure_kwargs(
|
|
rows,
|
|
cols,
|
|
row_size=row_size,
|
|
col_size=col_size,
|
|
figure_kwargs=figure_kwargs,
|
|
)
|
|
fig, axes = plt.subplots(
|
|
rows,
|
|
cols,
|
|
squeeze=False,
|
|
**figure_kwargs,
|
|
) # ty:ignore[no-matching-overload]
|
|
|
|
return fig, axes
|
|
|
|
    def _plot_base(
        self,
        subplot_fn: SubplotFn,
        context_fn: ContextFn,
        row_size: int | float = 2,
        col_size: int | float = 4,
        norm_samples: bool = False,
        combine_dims: bool = True,
        figure_kwargs: SubplotsKwargs | None = None,
    ) -> tuple[plt.Figure, AxesArray]:
        """
        Shared driver for the plot_* methods: one subplot row per dataloader,
        one call to ``subplot_fn`` per output dimension, and ``context_fn``
        (decoration + legend) once per finished axes.

        Note: transform samples in dataloader definitions beforehand if you
        want to change data

        Parameters:
            row_size: per-row figure height multiplier.
            col_size: per-column figure width multiplier.
            norm_samples: collapse each sample to its norm (single column;
                assumes 2-D (N, D) tensors — TODO confirm).
            combine_dims: draw every dimension on one axes per row instead of
                one column per dimension.
            figure_kwargs: overrides merged over the subplot defaults.
        """

        # number of output dimensions, from the last axis of the first "actual"
        ndims = self.data_tuples[0][0].size(-1)
        fig, axes = self._create_subplots(
            rows=len(self.dataloaders),
            # a single column when dims are collapsed (norm) or overlaid
            cols=1 if (norm_samples or combine_dims) else ndims,
            row_size=row_size,
            col_size=col_size,
            figure_kwargs=figure_kwargs,
        )

        # strict=True: axes rows and data triples must correspond 1:1
        for axes_row, data_tuple in zip(axes, self.data_tuples, strict=True):
            actual, output, loader_label = data_tuple

            if norm_samples:
                # collapse feature dims to one magnitude per sample;
                # keepdim so the [:, dim] indexing below still works at dim=0
                actual = actual.norm(dim=1, keepdim=True)
                output = output.norm(dim=1, keepdim=True)

            for dim in range(ndims):
                # combined: all dims share column 0; else one column per dim
                ax = axes_row[0 if combine_dims else dim]
                subplot_fn(ax, dim, actual, output)

                add_plot_context = (
                    norm_samples  # dim=0 implicit b/c we break
                    or not combine_dims  # add to every subplot across grid
                    or combine_dims and dim == ndims-1  # wait for last dim
                )

                # always exec plot logic if not combining, o/w exec just once
                if add_plot_context:
                    context_fn(ax, loader_label)

                    # transpose legend layout for more natural view
                    if norm_samples or not combine_dims:
                        ax.legend()
                    else:
                        handles, labels = self.get_transposed_handles_labels(ax)
                        ax.legend(handles, labels, ncols=2)

                # break dimension loop if collapsed by norm
                if norm_samples:
                    break

        return fig, axes
|
|
|
|
    def plot_actual_output(
        self,
        row_size: int | float = 2,
        col_size: int | float = 4,
        norm_samples: bool = False,
        combine_dims: bool = True,
        figure_kwargs: SubplotsKwargs | None = None,
        subplot_kwargs: dict | None = None,
        gof_kwargs: dict | None = None,
    ) -> tuple[plt.Figure, AxesArray]:
        """
        Scatter true labels against model outputs, one subplot row per
        dataloader, with a per-dimension least-squares goodness-of-fit line
        and a y=x "perfect prediction" reference.
        """

        subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs)
        gof_kwargs = self._prepare_gof_kwargs(gof_kwargs)
        # one color per dimension, cycling through the active style's palette
        colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

        def subplot_fn(
            ax: plt.Axes, dim: int, actual: Tensor, output: Tensor
        ) -> None:
            # scatter + fit line for one output dimension
            dim_color = colors[dim % len(colors)]
            group_name = "norm" if norm_samples else f"$d_{dim}$"

            dim_actual = actual[:, dim]
            dim_output = output[:, dim]

            ax.scatter(
                dim_actual, dim_output,
                color=dim_color,
                label=group_name,
                **subplot_kwargs,
            )

            # plot goodness of fit line
            m, b = self._lstsq_dim(dim_actual, dim_output)
            ax.plot(
                dim_actual, m * dim_actual + b,
                color=dim_color,
                label=f"GoF {group_name}",
                **gof_kwargs,
            )

        def context_fn(ax: plt.Axes, loader_label: str) -> None:
            # plot perfect prediction reference line, y=x
            # (axes coordinates, so the diagonal spans the full view)
            ax.plot(
                [0, 1], [0, 1],
                transform=ax.transAxes,
                c="black",
                alpha=0.2,
            )

            ax.set_title(
                f"[{loader_label}] True labels vs Predictions"
            )
            ax.set_xlabel("actual")
            ax.set_ylabel("output")

        return self._plot_base(
            subplot_fn, context_fn,
            row_size=row_size,
            col_size=col_size,
            norm_samples=norm_samples,
            combine_dims=combine_dims,
            figure_kwargs=figure_kwargs,
        )
|
|
|
|
    def plot_actual_output_residual(
        self,
        row_size: int | float = 2,
        col_size: int | float = 4,
        order_residuals: bool = False,
        norm_samples: bool = False,
        combine_dims: bool = True,
        figure_kwargs: SubplotsKwargs | None = None,
        subplot_kwargs: dict | None = None,
        gof_kwargs: dict | None = None,
    ) -> tuple[plt.Figure, AxesArray]:
        """
        Plot prediction residuals (actual - output) against the actual
        values, one subplot row per dataloader, with a per-dimension
        goodness-of-fit line and a y=0 reference.

        When ``order_residuals`` is set, residuals are instead sorted and
        plotted against their rank, and the fit line is skipped.
        """

        subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs)
        gof_kwargs = self._prepare_gof_kwargs(gof_kwargs)
        # one color per dimension, cycling through the active style's palette
        colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

        def subplot_fn(
            ax: plt.Axes, dim: int, actual: Tensor, output: Tensor
        ) -> None:
            dim_color = colors[dim % len(colors)]
            group_name = "norm" if norm_samples else f"$d_{dim}$"

            dim_actual = actual[:, dim]
            dim_output = output[:, dim]

            residuals = dim_actual - dim_output
            X, Y = dim_actual, residuals

            if order_residuals:
                # rank on x, sorted residuals on y
                X = range(1, residuals.size(0)+1)
                Y = residuals[residuals.argsort()]

            ax.scatter(
                X, Y,
                color=dim_color,
                label=group_name,
                **subplot_kwargs,
            )

            # plot goodness of fit line
            # (skipped for ordered residuals: rank has no linear relation)
            if not order_residuals:
                m, b = self._lstsq_dim(dim_actual, residuals)
                ax.plot(
                    dim_actual, m * dim_actual + b,
                    color=dim_color,
                    label=f"GoF {group_name}",
                    **gof_kwargs,
                )

        def context_fn(ax: plt.Axes, loader_label: str) -> None:
            # compare residuals to y=0
            ax.axhline(y=0, c="black", alpha=0.2)

            ax.set_title(f"[{loader_label}] Prediction residuals")
            # NOTE(review): when order_residuals is set the x-axis is the
            # sample rank, so "actual" is a misnomer — consider a
            # conditional label here.
            ax.set_xlabel("actual")
            ax.set_ylabel("residual")

        return self._plot_base(
            subplot_fn, context_fn,
            row_size=row_size,
            col_size=col_size,
            norm_samples=norm_samples,
            combine_dims=combine_dims,
            figure_kwargs=figure_kwargs,
        )
|
|
|
|
def plot_actual_output_residual_dist(
|
|
self,
|
|
row_size: int | float = 2,
|
|
col_size: int | float = 4,
|
|
norm_samples: bool = False,
|
|
combine_dims: bool = True,
|
|
figure_kwargs: SubplotsKwargs | None = None,
|
|
subplot_kwargs: dict | None = None,
|
|
gof_kwargs: dict | None = None,
|
|
) -> tuple[plt.Figure, AxesArray]:
|
|
"""
|
|
Plot residual distribution.
|
|
"""
|
|
|
|
subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs)
|
|
gof_kwargs = self._prepare_gof_kwargs(gof_kwargs)
|
|
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
|
|
|
|
def subplot_fn(
|
|
ax: plt.Axes, dim: int, actual: Tensor, output: Tensor
|
|
) -> None:
|
|
dim_color = colors[dim % len(colors)]
|
|
group_name = "norm" if norm_samples else f"$d_{dim}$"
|
|
|
|
dim_actual = actual[:, dim]
|
|
dim_output = output[:, dim]
|
|
|
|
N = dim_actual.size(0)
|
|
residuals = dim_actual - dim_output
|
|
|
|
_, _, patches = ax.hist(
|
|
residuals.abs(),
|
|
bins=int(np.sqrt(N)),
|
|
density=True,
|
|
alpha=0.3,
|
|
color=dim_color,
|
|
label=group_name,
|
|
**subplot_kwargs
|
|
)
|
|
|
|
mu = residuals.abs().mean().item()
|
|
ax.axvline(mu, linestyle=":", c=dim_color, label=f"$\\mu_{dim}$")
|
|
|
|
def context_fn(ax: plt.Axes, loader_label: str) -> None:
|
|
ax.set_title(f"[{loader_label}] Residual distribution")
|
|
ax.set_xlabel("actual")
|
|
ax.set_ylabel("residual (density)")
|
|
|
|
return self._plot_base(
|
|
subplot_fn, context_fn,
|
|
row_size=row_size,
|
|
col_size=col_size,
|
|
norm_samples=norm_samples,
|
|
combine_dims=combine_dims,
|
|
figure_kwargs=figure_kwargs,
|
|
)
|