refactor Plotter to use a common plot base
This commit is contained in:
@@ -1,19 +1,134 @@
|
|||||||
from typing import Self
|
"""
|
||||||
|
def plot_actual_output_residual(
|
||||||
|
self,
|
||||||
|
order_residuals: bool = False,
|
||||||
|
row_size: int | float = 2,
|
||||||
|
col_size: int | float = 4,
|
||||||
|
norm_samples: bool = False,
|
||||||
|
combine_dims: bool = True,
|
||||||
|
figure_kwargs: SubplotsKwargs | None = None,
|
||||||
|
subplot_kwargs: dict | None = None,
|
||||||
|
gof_kwargs: dict | None = None,
|
||||||
|
) -> tuple[plt.Figure, AxesArray]:
|
||||||
|
Note: transform samples in dataloader definitions beforehand if you
|
||||||
|
want to change data
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
row_size:
|
||||||
|
col_size:
|
||||||
|
figure_kwargs:
|
||||||
|
subplot_kwargs:
|
||||||
|
gof_kwargs:
|
||||||
|
|
||||||
|
ndims = self.data_tuples[0][0].size(-1)
|
||||||
|
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
|
||||||
|
fig, axes = self._create_subplots(
|
||||||
|
rows=len(self.dataloaders),
|
||||||
|
cols=1 if (norm_samples or combine_dims) else ndims,
|
||||||
|
row_size=row_size,
|
||||||
|
col_size=col_size,
|
||||||
|
figure_kwargs=figure_kwargs,
|
||||||
|
)
|
||||||
|
subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs)
|
||||||
|
gof_kwargs = self._prepare_gof_kwargs(gof_kwargs)
|
||||||
|
|
||||||
|
for axes_row, data_tuple in zip(axes, self.data_tuples, strict=True):
|
||||||
|
actual, output, loader_label = data_tuple
|
||||||
|
|
||||||
|
if norm_samples:
|
||||||
|
actual = actual.norm(dim=1, keepdim=True)
|
||||||
|
output = output.norm(dim=1, keepdim=True)
|
||||||
|
|
||||||
|
for dim in range(ndims):
|
||||||
|
ax = axes_row[0 if combine_dims else dim]
|
||||||
|
dim_color = colors[dim % len(colors)]
|
||||||
|
group_name = "norm" if norm_samples else f"$d_{dim}$"
|
||||||
|
|
||||||
|
dim_actual = actual[:, dim]
|
||||||
|
dim_output = output[:, dim]
|
||||||
|
residuals = dim_actual - dim_output
|
||||||
|
|
||||||
|
X, Y = dim_actual, residuals
|
||||||
|
if order_residuals:
|
||||||
|
X = range(1, residuals.size(0)+1)
|
||||||
|
Y = residuals[residuals.argsort()]
|
||||||
|
|
||||||
|
ax.scatter(
|
||||||
|
X, Y,
|
||||||
|
color=dim_color,
|
||||||
|
label=group_name,
|
||||||
|
**subplot_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# plot goodness of fit line
|
||||||
|
m, b = self._lstsq_dim(dim_actual, dim_output)
|
||||||
|
ax.plot(
|
||||||
|
dim_actual,
|
||||||
|
m * dim_actual + b,
|
||||||
|
color=dim_color,
|
||||||
|
label=f"GoF {group_name}",
|
||||||
|
**gof_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
add_plot_context = (
|
||||||
|
norm_samples # dim=0 implicit b/c we break
|
||||||
|
or not combine_dims # add to every subplot across grid
|
||||||
|
or combine_dims and dim == ndims-1 # wait for last dim
|
||||||
|
)
|
||||||
|
|
||||||
|
# always exec plot logic if not combining, o/w just once
|
||||||
|
if add_plot_context:
|
||||||
|
# compare residuals to y=0
|
||||||
|
ax.axhline(y=0, c="black", alpha=0.2)
|
||||||
|
|
||||||
|
ax.set_title(f"[{loader_label}] Prediction residuals")
|
||||||
|
ax.set_xlabel("actual")
|
||||||
|
ax.set_ylabel("residual")
|
||||||
|
|
||||||
|
# transpose legend layout for more natural view
|
||||||
|
if norm_samples or not combine_dims:
|
||||||
|
ax.legend()
|
||||||
|
else:
|
||||||
|
handles, labels = self.get_transposed_handles_labels(ax)
|
||||||
|
ax.legend(handles, labels, ncols=2)
|
||||||
|
|
||||||
|
# break dimension loop if collapsed by norm
|
||||||
|
if norm_samples:
|
||||||
|
break
|
||||||
|
|
||||||
|
return fig, axes
|
||||||
|
"""
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from collections.abc import Callable, Generator
|
from collections.abc import Callable
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from numpy.typing import NDArray
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from trainlib.trainer import Trainer
|
from trainlib.trainer import Trainer
|
||||||
from trainlib.estimator import EstimatorKwargs
|
from trainlib.estimator import EstimatorKwargs
|
||||||
from trainlib.utils.type import AxesArray, SubplotsKwargs
|
from trainlib.utils.type import AxesArray, SubplotsKwargs
|
||||||
|
|
||||||
|
type SubplotFn = Callable[[plt.Axes, int, Tensor, Tensor], None]
|
||||||
|
type ContextFn = Callable[[plt.Axes, str], None]
|
||||||
|
|
||||||
|
|
||||||
class Plotter[B, K: EstimatorKwargs]:
|
class Plotter[B, K: EstimatorKwargs]:
|
||||||
|
"""
|
||||||
|
TODOs:
|
||||||
|
|
||||||
|
- best fit lines for plots and residuals (compare to ideal lines in each
|
||||||
|
case)
|
||||||
|
- show val options across columns; preview how val is changing across
|
||||||
|
natural training, and what the best will look (so plot like uniform
|
||||||
|
intervals broken over the training epochs at 0, 50, 100, 150, ... and
|
||||||
|
highlight the best one, even if that's not actually the single best
|
||||||
|
epoch)
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
trainer: Trainer[..., K],
|
trainer: Trainer[..., K],
|
||||||
@@ -35,11 +150,11 @@ class Plotter[B, K: EstimatorKwargs]:
|
|||||||
|
|
||||||
self._batch_outputs_fn = partial(
|
self._batch_outputs_fn = partial(
|
||||||
self.trainer.get_batch_outputs,
|
self.trainer.get_batch_outputs,
|
||||||
batch_estimator_map=self.batch_estimator_map
|
batch_estimator_map=batch_estimator_map
|
||||||
)
|
)
|
||||||
self._batch_metrics_fn = partial(
|
self._batch_metrics_fn = partial(
|
||||||
self.trainer.get_batch_metrics,
|
self.trainer.get_batch_metrics,
|
||||||
batch_estimator_map=self.batch_estimator_map
|
batch_estimator_map=batch_estimator_map
|
||||||
)
|
)
|
||||||
|
|
||||||
self._data_tuples = None
|
self._data_tuples = None
|
||||||
@@ -83,201 +198,369 @@ class Plotter[B, K: EstimatorKwargs]:
|
|||||||
data_tuples.append((actual, output, label))
|
data_tuples.append((actual, output, label))
|
||||||
|
|
||||||
self._data_tuples = data_tuples
|
self._data_tuples = data_tuples
|
||||||
|
|
||||||
return self._data_tuples
|
return self._data_tuples
|
||||||
|
|
||||||
def _default_figure_kwargs(self, rows: int, cols: int) -> dict:
|
def get_transposed_handles_labels(
|
||||||
return {
|
|
||||||
"sharex": True,
|
|
||||||
"figsize": (4*cols, 2*rows),
|
|
||||||
"constrained_layout": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
def _default_subplot_kwargs(self) -> dict:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def _create_subplots(
|
|
||||||
self,
|
self,
|
||||||
**figure_kwargs: SubplotsKwargs,
|
ax: plt.Axes
|
||||||
) -> tuple[plt.Figure, AxesArray]:
|
) -> tuple[list, list]:
|
||||||
"""
|
# transpose legend layout for more natural view
|
||||||
"""
|
handles, labels = ax.get_legend_handles_labels()
|
||||||
|
handles = handles[::2] + handles[1::2]
|
||||||
|
labels = labels[::2] + labels[1::2]
|
||||||
|
|
||||||
rows, cols = len(self.dataloaders), 1
|
return handles, labels
|
||||||
figure_kwargs = {
|
|
||||||
**self._default_figure_kwargs(rows, cols),
|
def _lstsq_dim(
|
||||||
**figure_kwargs,
|
self,
|
||||||
}
|
dim_actual: Tensor,
|
||||||
fig, axes = plt.subplots(
|
dim_output: Tensor
|
||||||
rows, cols,
|
) -> tuple[Tensor, Tensor]:
|
||||||
squeeze=False,
|
A = torch.stack(
|
||||||
**figure_kwargs
|
[dim_actual, torch.ones_like(dim_actual)],
|
||||||
|
dim=1
|
||||||
)
|
)
|
||||||
|
m, b = torch.linalg.lstsq(A, dim_output).solution
|
||||||
|
|
||||||
return fig, axes
|
return m, b
|
||||||
|
|
||||||
def plot_actual_output_dim(
|
def _prepare_figure_kwargs(
|
||||||
self,
|
self,
|
||||||
figure_kwargs: dict | None = None,
|
rows: int,
|
||||||
subplot_kwargs: dict | None = None,
|
cols: int,
|
||||||
) -> tuple[plt.Figure, AxesArray]:
|
row_size: int | float = 2,
|
||||||
|
col_size: int | float = 4,
|
||||||
|
figure_kwargs: SubplotsKwargs | None = None,
|
||||||
|
) -> SubplotsKwargs:
|
||||||
"""
|
"""
|
||||||
Wrapper like this works fine, but it's smelly: we *don't* want @wraps,
|
|
||||||
do this method doesn't actually have this signature at runtime (it has
|
|
||||||
the dec wrapper's sig). I think the cleaner thing is to just have
|
|
||||||
internal methods (_func) like the one below, and then the main method
|
|
||||||
entry just pass that internal method through to the skeleton
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
figure_kwargs = figure_kwargs or {}
|
default_figure_kwargs = {
|
||||||
|
"sharex": True,
|
||||||
|
"figsize": (col_size*cols, row_size*rows),
|
||||||
|
}
|
||||||
|
figure_kwargs = {
|
||||||
|
**default_figure_kwargs,
|
||||||
|
**(figure_kwargs or {}),
|
||||||
|
}
|
||||||
|
|
||||||
|
return figure_kwargs
|
||||||
|
|
||||||
|
def _prepare_subplot_kwargs(
|
||||||
|
self,
|
||||||
|
subplot_kwargs: dict | None = None,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
|
||||||
|
default_subplot_kwargs = {}
|
||||||
subplot_kwargs = {
|
subplot_kwargs = {
|
||||||
**self._default_subplot_kwargs(),
|
**default_subplot_kwargs,
|
||||||
**(subplot_kwargs or {}),
|
**(subplot_kwargs or {}),
|
||||||
}
|
}
|
||||||
|
|
||||||
fig, axes = self._create_subplots(**figure_kwargs)
|
return subplot_kwargs
|
||||||
for ax, data_tuple in zip(axes, self.data_tuples, strict=True):
|
|
||||||
actual, output, label = data_tuple
|
|
||||||
|
|
||||||
|
def _prepare_gof_kwargs(
|
||||||
|
self,
|
||||||
|
gof_kwargs: dict | None = None,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
|
||||||
|
default_gof_kwargs = {
|
||||||
|
"alpha": 0.5
|
||||||
|
}
|
||||||
|
gof_kwargs = {
|
||||||
|
**default_gof_kwargs,
|
||||||
|
**(gof_kwargs or {}),
|
||||||
|
}
|
||||||
|
|
||||||
|
return gof_kwargs
|
||||||
|
|
||||||
|
def _create_subplots(
|
||||||
|
self,
|
||||||
|
rows: int,
|
||||||
|
cols: int,
|
||||||
|
row_size: int | float = 2,
|
||||||
|
col_size: int | float = 4,
|
||||||
|
figure_kwargs: SubplotsKwargs | None = None,
|
||||||
|
) -> tuple[plt.Figure, AxesArray]:
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
|
||||||
|
figure_kwargs: SubplotsKwargs = self._prepare_figure_kwargs(
|
||||||
|
rows,
|
||||||
|
cols,
|
||||||
|
row_size=row_size,
|
||||||
|
col_size=col_size,
|
||||||
|
figure_kwargs=figure_kwargs,
|
||||||
|
)
|
||||||
|
fig, axes = plt.subplots(
|
||||||
|
rows,
|
||||||
|
cols,
|
||||||
|
squeeze=False,
|
||||||
|
**figure_kwargs,
|
||||||
|
) # ty:ignore[no-matching-overload]
|
||||||
|
|
||||||
|
return fig, axes
|
||||||
|
|
||||||
|
def _plot_base(
|
||||||
|
self,
|
||||||
|
subplot_fn: SubplotFn,
|
||||||
|
context_fn: ContextFn,
|
||||||
|
row_size: int | float = 2,
|
||||||
|
col_size: int | float = 4,
|
||||||
|
norm_samples: bool = False,
|
||||||
|
combine_dims: bool = True,
|
||||||
|
figure_kwargs: SubplotsKwargs | None = None,
|
||||||
|
) -> tuple[plt.Figure, AxesArray]:
|
||||||
|
"""
|
||||||
|
Note: transform samples in dataloader definitions beforehand if you
|
||||||
|
want to change data
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
row_size:
|
||||||
|
col_size:
|
||||||
|
figure_kwargs:
|
||||||
|
subplot_kwargs:
|
||||||
|
gof_kwargs:
|
||||||
|
"""
|
||||||
|
|
||||||
|
ndims = self.data_tuples[0][0].size(-1)
|
||||||
|
fig, axes = self._create_subplots(
|
||||||
|
rows=len(self.dataloaders),
|
||||||
|
cols=1 if (norm_samples or combine_dims) else ndims,
|
||||||
|
row_size=row_size,
|
||||||
|
col_size=col_size,
|
||||||
|
figure_kwargs=figure_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
for axes_row, data_tuple in zip(axes, self.data_tuples, strict=True):
|
||||||
|
actual, output, loader_label = data_tuple
|
||||||
|
|
||||||
|
if norm_samples:
|
||||||
|
actual = actual.norm(dim=1, keepdim=True)
|
||||||
|
output = output.norm(dim=1, keepdim=True)
|
||||||
|
|
||||||
|
for dim in range(ndims):
|
||||||
|
ax = axes_row[0 if combine_dims else dim]
|
||||||
|
|
||||||
|
subplot_fn(ax, dim, actual, output)
|
||||||
|
|
||||||
|
add_plot_context = (
|
||||||
|
norm_samples # dim=0 implicit b/c we break
|
||||||
|
or not combine_dims # add to every subplot across grid
|
||||||
|
or combine_dims and dim == ndims-1 # wait for last dim
|
||||||
|
)
|
||||||
|
|
||||||
|
# always exec plot logic if not combining, o/w exec just once
|
||||||
|
if add_plot_context:
|
||||||
|
context_fn(ax, loader_label)
|
||||||
|
|
||||||
|
# transpose legend layout for more natural view
|
||||||
|
if norm_samples or not combine_dims:
|
||||||
|
ax.legend()
|
||||||
|
else:
|
||||||
|
handles, labels = self.get_transposed_handles_labels(ax)
|
||||||
|
ax.legend(handles, labels, ncols=2)
|
||||||
|
|
||||||
|
# break dimension loop if collapsed by norm
|
||||||
|
if norm_samples:
|
||||||
|
break
|
||||||
|
|
||||||
|
return fig, axes
|
||||||
|
|
||||||
|
def plot_actual_output(
|
||||||
|
self,
|
||||||
|
row_size: int | float = 2,
|
||||||
|
col_size: int | float = 4,
|
||||||
|
norm_samples: bool = False,
|
||||||
|
combine_dims: bool = True,
|
||||||
|
figure_kwargs: SubplotsKwargs | None = None,
|
||||||
|
subplot_kwargs: dict | None = None,
|
||||||
|
gof_kwargs: dict | None = None,
|
||||||
|
) -> tuple[plt.Figure, AxesArray]:
|
||||||
|
"""
|
||||||
|
Plot residual distribution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs)
|
||||||
|
gof_kwargs = self._prepare_gof_kwargs(gof_kwargs)
|
||||||
|
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
|
||||||
|
|
||||||
|
def subplot_fn(
|
||||||
|
ax: plt.Axes, dim: int, actual: Tensor, output: Tensor
|
||||||
|
) -> None:
|
||||||
|
dim_color = colors[dim % len(colors)]
|
||||||
|
group_name = "norm" if norm_samples else f"$d_{dim}$"
|
||||||
|
|
||||||
|
dim_actual = actual[:, dim]
|
||||||
|
dim_output = output[:, dim]
|
||||||
|
|
||||||
|
ax.scatter(
|
||||||
|
dim_actual, dim_output,
|
||||||
|
color=dim_color,
|
||||||
|
label=group_name,
|
||||||
|
**subplot_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# plot goodness of fit line
|
||||||
|
m, b = self._lstsq_dim(dim_actual, dim_output)
|
||||||
|
ax.plot(
|
||||||
|
dim_actual, m * dim_actual + b,
|
||||||
|
color=dim_color,
|
||||||
|
label=f"GoF {group_name}",
|
||||||
|
**gof_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def context_fn(ax: plt.Axes, loader_label: str) -> None:
|
||||||
|
# plot perfect prediction reference line, y=x
|
||||||
ax.plot(
|
ax.plot(
|
||||||
[0, 1], [0, 1],
|
[0, 1], [0, 1],
|
||||||
transform=ax.transAxes,
|
transform=ax.transAxes,
|
||||||
c="black",
|
c="black",
|
||||||
alpha=0.2
|
alpha=0.2,
|
||||||
)
|
)
|
||||||
|
|
||||||
for dim in range(actual.size(-1)):
|
ax.set_title(
|
||||||
ax.scatter(
|
f"[{loader_label}] True labels vs Predictions"
|
||||||
actual[:, dim],
|
)
|
||||||
output[:, dim],
|
|
||||||
label=f"$d_{dim}$",
|
|
||||||
**subplot_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
ax.set_title(f"[{label}] True labels vs Predictions (dim-wise)")
|
|
||||||
ax.set_xlabel("actual")
|
ax.set_xlabel("actual")
|
||||||
ax.set_ylabel("output")
|
ax.set_ylabel("output")
|
||||||
ax.legend()
|
|
||||||
|
|
||||||
return fig, axes
|
return self._plot_base(
|
||||||
|
subplot_fn, context_fn,
|
||||||
|
row_size=row_size,
|
||||||
|
col_size=col_size,
|
||||||
|
norm_samples=norm_samples,
|
||||||
|
combine_dims=combine_dims,
|
||||||
|
figure_kwargs=figure_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
def plot_actual_output_residual_dim(
|
def plot_actual_output_residual(
|
||||||
self,
|
self,
|
||||||
figure_kwargs: dict | None = None,
|
row_size: int | float = 2,
|
||||||
|
col_size: int | float = 4,
|
||||||
|
order_residuals: bool = False,
|
||||||
|
norm_samples: bool = False,
|
||||||
|
combine_dims: bool = True,
|
||||||
|
figure_kwargs: SubplotsKwargs | None = None,
|
||||||
subplot_kwargs: dict | None = None,
|
subplot_kwargs: dict | None = None,
|
||||||
|
gof_kwargs: dict | None = None,
|
||||||
) -> tuple[plt.Figure, AxesArray]:
|
) -> tuple[plt.Figure, AxesArray]:
|
||||||
"""
|
"""
|
||||||
|
Plot prediction residuals.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
figure_kwargs = figure_kwargs or {}
|
subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs)
|
||||||
subplot_kwargs = {
|
gof_kwargs = self._prepare_gof_kwargs(gof_kwargs)
|
||||||
**self._default_subplot_kwargs(),
|
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
|
||||||
**(subplot_kwargs or {}),
|
|
||||||
}
|
|
||||||
|
|
||||||
fig, axes = self._create_subplots(**figure_kwargs)
|
def subplot_fn(
|
||||||
for ax, data_tuple in zip(axes, self.data_tuples, strict=True):
|
ax: plt.Axes, dim: int, actual: Tensor, output: Tensor
|
||||||
actual, output, label = data_tuple
|
) -> None:
|
||||||
|
dim_color = colors[dim % len(colors)]
|
||||||
|
group_name = "norm" if norm_samples else f"$d_{dim}$"
|
||||||
|
|
||||||
|
dim_actual = actual[:, dim]
|
||||||
|
dim_output = output[:, dim]
|
||||||
|
|
||||||
|
residuals = dim_actual - dim_output
|
||||||
|
X, Y = dim_actual, residuals
|
||||||
|
|
||||||
|
if order_residuals:
|
||||||
|
X = range(1, residuals.size(0)+1)
|
||||||
|
Y = residuals[residuals.argsort()]
|
||||||
|
|
||||||
|
ax.scatter(
|
||||||
|
X, Y,
|
||||||
|
color=dim_color,
|
||||||
|
label=group_name,
|
||||||
|
**subplot_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# plot goodness of fit line
|
||||||
|
if not order_residuals:
|
||||||
|
m, b = self._lstsq_dim(dim_actual, residuals)
|
||||||
|
ax.plot(
|
||||||
|
dim_actual, m * dim_actual + b,
|
||||||
|
color=dim_color,
|
||||||
|
label=f"GoF {group_name}",
|
||||||
|
**gof_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def context_fn(ax: plt.Axes, loader_label: str) -> None:
|
||||||
# compare residuals to y=0
|
# compare residuals to y=0
|
||||||
ax.axhline(y=0, c="black", alpha=0.2)
|
ax.axhline(y=0, c="black", alpha=0.2)
|
||||||
|
|
||||||
for dim in range(actual.size(-1)):
|
ax.set_title(f"[{loader_label}] Prediction residuals")
|
||||||
ax.scatter(
|
|
||||||
actual[:, dim],
|
|
||||||
actual[:, dim] - output[:, dim],
|
|
||||||
label=f"$d_{dim}$",
|
|
||||||
**subplot_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
ax.set_title(f"[{label}] Residuals (dim-wise)")
|
|
||||||
ax.set_xlabel("actual")
|
ax.set_xlabel("actual")
|
||||||
ax.set_ylabel("residual")
|
ax.set_ylabel("residual")
|
||||||
ax.legend()
|
|
||||||
|
|
||||||
return fig, axes
|
|
||||||
|
|
||||||
def plot_actual_output_ordered_residual_dim(
|
|
||||||
self,
|
|
||||||
figure_kwargs: dict | None = None,
|
|
||||||
subplot_kwargs: dict | None = None,
|
|
||||||
) -> tuple[plt.Figure, AxesArray]:
|
|
||||||
"""
|
|
||||||
"""
|
|
||||||
|
|
||||||
figure_kwargs = figure_kwargs or {}
|
|
||||||
subplot_kwargs = {
|
|
||||||
**self._default_subplot_kwargs(),
|
|
||||||
**(subplot_kwargs or {}),
|
|
||||||
}
|
|
||||||
|
|
||||||
fig, axes = self._create_subplots(**figure_kwargs)
|
|
||||||
for ax, data_tuple in zip(axes, self.data_tuples, strict=True):
|
|
||||||
actual, output, label = data_tuple
|
|
||||||
|
|
||||||
# compare residuals to y=0
|
|
||||||
ax.axhline(y=0, c="black", alpha=0.2)
|
|
||||||
|
|
||||||
for dim in range(actual.size(-1)):
|
|
||||||
ax.scatter(
|
|
||||||
actual[:, dim],
|
|
||||||
actual[:, dim] - output[:, dim],
|
|
||||||
label=f"$d_{dim}$",
|
|
||||||
**subplot_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
ax.set_title(f"[{label}] Residuals (dim-wise)")
|
|
||||||
ax.set_xlabel("actual")
|
|
||||||
ax.set_ylabel("residual")
|
|
||||||
ax.legend()
|
|
||||||
|
|
||||||
return fig, axes
|
|
||||||
|
|
||||||
|
return self._plot_base(
|
||||||
|
subplot_fn, context_fn,
|
||||||
|
row_size=row_size,
|
||||||
|
col_size=col_size,
|
||||||
|
norm_samples=norm_samples,
|
||||||
|
combine_dims=combine_dims,
|
||||||
|
figure_kwargs=figure_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
def plot_actual_output_residual_dist(
|
def plot_actual_output_residual_dist(
|
||||||
self,
|
self,
|
||||||
figure_kwargs: dict | None = None,
|
row_size: int | float = 2,
|
||||||
|
col_size: int | float = 4,
|
||||||
|
norm_samples: bool = False,
|
||||||
|
combine_dims: bool = True,
|
||||||
|
figure_kwargs: SubplotsKwargs | None = None,
|
||||||
subplot_kwargs: dict | None = None,
|
subplot_kwargs: dict | None = None,
|
||||||
|
gof_kwargs: dict | None = None,
|
||||||
) -> tuple[plt.Figure, AxesArray]:
|
) -> tuple[plt.Figure, AxesArray]:
|
||||||
"""
|
"""
|
||||||
|
Plot residual distribution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
figure_kwargs = figure_kwargs or {}
|
subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs)
|
||||||
subplot_kwargs = {
|
gof_kwargs = self._prepare_gof_kwargs(gof_kwargs)
|
||||||
**self._default_subplot_kwargs(),
|
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
|
||||||
**(subplot_kwargs or {}),
|
|
||||||
}
|
|
||||||
|
|
||||||
fig, axes = self._create_subplots(**figure_kwargs)
|
def subplot_fn(
|
||||||
for ax, data_tuple in zip(axes, self.data_tuples, strict=True):
|
ax: plt.Axes, dim: int, actual: Tensor, output: Tensor
|
||||||
actual, output, label = data_tuple
|
) -> None:
|
||||||
|
dim_color = colors[dim % len(colors)]
|
||||||
|
group_name = "norm" if norm_samples else f"$d_{dim}$"
|
||||||
|
|
||||||
N = actual.size(0)
|
dim_actual = actual[:, dim]
|
||||||
for dim in range(actual.size(-1)):
|
dim_output = output[:, dim]
|
||||||
residuals = actual[:, dim] - output[:, dim]
|
|
||||||
|
|
||||||
_, _, patches = ax.hist(
|
N = dim_actual.size(0)
|
||||||
residuals.abs(),
|
residuals = dim_actual - dim_output
|
||||||
bins=int(np.sqrt(N)),
|
|
||||||
density=True,
|
|
||||||
alpha=0.2,
|
|
||||||
label=f"$d_{dim}$",
|
|
||||||
**subplot_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
# grab color used for hist and mirror in the v-line
|
_, _, patches = ax.hist(
|
||||||
color = patches[0].get_facecolor()
|
residuals.abs(),
|
||||||
|
bins=int(np.sqrt(N)),
|
||||||
|
density=True,
|
||||||
|
alpha=0.3,
|
||||||
|
color=dim_color,
|
||||||
|
label=group_name,
|
||||||
|
**subplot_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
mu = residuals.abs().mean().item()
|
mu = residuals.abs().mean().item()
|
||||||
ax.axvline(mu, linestyle=":", c=color, label=f"$\mu_{dim}$")
|
ax.axvline(mu, linestyle=":", c=dim_color, label=f"$\\mu_{dim}$")
|
||||||
|
|
||||||
ax.set_title(f"[{label}] Residual distribution (dim-wise)")
|
def context_fn(ax: plt.Axes, loader_label: str) -> None:
|
||||||
|
ax.set_title(f"[{loader_label}] Residual distribution")
|
||||||
ax.set_xlabel("actual")
|
ax.set_xlabel("actual")
|
||||||
ax.set_ylabel("residual")
|
ax.set_ylabel("residual (density)")
|
||||||
|
|
||||||
# transpose legend layout for more natural view
|
return self._plot_base(
|
||||||
handles, labels = ax.get_legend_handles_labels()
|
subplot_fn, context_fn,
|
||||||
handles = handles[::2] + handles[1::2]
|
row_size=row_size,
|
||||||
labels = labels[::2] + labels[1::2]
|
col_size=col_size,
|
||||||
|
norm_samples=norm_samples,
|
||||||
ax.legend(handles, labels, ncols=2)
|
combine_dims=combine_dims,
|
||||||
|
figure_kwargs=figure_kwargs,
|
||||||
return fig, axes
|
)
|
||||||
|
|||||||
@@ -1,46 +1,49 @@
|
|||||||
text.usetex : False
|
text.usetex : False
|
||||||
mathtext.default : regular
|
mathtext.default : regular
|
||||||
|
|
||||||
font.family : sans-serif
|
# testing to prevent component overlap/clipping
|
||||||
font.sans-serif : DejaVu Sans
|
figure.constrained_layout.use : True
|
||||||
font.serif : DejaVu Serif
|
|
||||||
font.cursive : DejaVu Sans
|
|
||||||
mathtext.fontset : dejavuserif
|
|
||||||
font.size : 9
|
|
||||||
figure.titlesize : 9
|
|
||||||
legend.fontsize : 9
|
|
||||||
axes.titlesize : 9
|
|
||||||
axes.labelsize : 9
|
|
||||||
xtick.labelsize : 9
|
|
||||||
ytick.labelsize : 9
|
|
||||||
|
|
||||||
#axes.prop_cycle : cycler('color', ['4f7dd5', 'af7031', '55905e', 'd84739', '888348', 'b75e8b', '2f8f99', '9862cb'])
|
font.family : sans-serif
|
||||||
axes.prop_cycle : cycler('color', ['5e8de4', 'c38141', '67a771', 'e15344', '9e9858', '41a6b0', 'a46fd7', 'c86d9a'])
|
font.sans-serif : DejaVu Sans
|
||||||
|
font.serif : DejaVu Serif
|
||||||
|
font.cursive : DejaVu Sans
|
||||||
|
mathtext.fontset : dejavuserif
|
||||||
|
font.size : 9
|
||||||
|
figure.titlesize : 9
|
||||||
|
legend.fontsize : 9
|
||||||
|
axes.titlesize : 9
|
||||||
|
axes.labelsize : 9
|
||||||
|
xtick.labelsize : 9
|
||||||
|
ytick.labelsize : 9
|
||||||
|
|
||||||
image.interpolation : nearest
|
# monobiome -d 0.45 -l 22
|
||||||
image.resample : False
|
axes.prop_cycle : cycler('color', ['5e8de4', 'c38141', '67a771', 'e15344', '9e9858', '41a6b0', 'a46fd7', 'c86d9a'])
|
||||||
image.composite_image : True
|
|
||||||
|
|
||||||
axes.spines.left : True
|
image.interpolation : nearest
|
||||||
axes.spines.bottom : True
|
image.resample : False
|
||||||
axes.spines.top : False
|
image.composite_image : True
|
||||||
axes.spines.right : False
|
|
||||||
|
|
||||||
axes.linewidth : 1
|
axes.spines.left : True
|
||||||
xtick.major.width : 1
|
axes.spines.bottom : True
|
||||||
xtick.minor.width : 1
|
axes.spines.top : False
|
||||||
ytick.major.width : 1
|
axes.spines.right : False
|
||||||
ytick.minor.width : 1
|
|
||||||
|
|
||||||
lines.linewidth : 1
|
axes.linewidth : 1
|
||||||
lines.markersize : 1
|
xtick.major.width : 1
|
||||||
|
xtick.minor.width : 1
|
||||||
|
ytick.major.width : 1
|
||||||
|
ytick.minor.width : 1
|
||||||
|
|
||||||
savefig.dpi : 300
|
lines.linewidth : 1
|
||||||
savefig.format : svg
|
lines.markersize : 1
|
||||||
savefig.bbox : tight
|
|
||||||
savefig.pad_inches : 0.1
|
|
||||||
|
|
||||||
svg.image_inline : True
|
savefig.dpi : 300
|
||||||
svg.fonttype : none
|
savefig.format : svg
|
||||||
|
savefig.bbox : tight
|
||||||
|
savefig.pad_inches : 0.1
|
||||||
|
|
||||||
legend.frameon : False
|
svg.image_inline : True
|
||||||
|
svg.fonttype : none
|
||||||
|
|
||||||
|
legend.frameon : False
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from typing import Any, TypedDict
|
from typing import Any, TypedDict
|
||||||
from collections.abc import Callable, Iterable
|
from collections.abc import Callable, Iterable, Sequence
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
@@ -59,18 +59,11 @@ class SubplotsKwargs(TypedDict, total=False):
|
|||||||
sharex: bool | str
|
sharex: bool | str
|
||||||
sharey: bool | str
|
sharey: bool | str
|
||||||
squeeze: bool
|
squeeze: bool
|
||||||
width_ratios: list[float]
|
width_ratios: Sequence[float]
|
||||||
height_ratios: list[float]
|
height_ratios: Sequence[float]
|
||||||
subplot_kw: dict
|
subplot_kw: dict[str, ...]
|
||||||
gridspec_kw: dict
|
gridspec_kw: dict[str, ...]
|
||||||
figsize: tuple[float, float]
|
figsize: tuple[float, float]
|
||||||
dpi: float
|
dpi: float
|
||||||
layout: str
|
layout: str
|
||||||
|
constrained_layout: bool
|
||||||
sharex: bool | Literal["none", "all", "row", "col"] = False,
|
|
||||||
sharey: bool | Literal["none", "all", "row", "col"] = False,
|
|
||||||
squeeze: bool = True,
|
|
||||||
width_ratios: Sequence[float] | None = None,
|
|
||||||
height_ratios: Sequence[float] | None = None,
|
|
||||||
subplot_kw: dict[str, Any] | None = None,
|
|
||||||
gridspec_kw: dict[str, Any] | None = None,
|
|
||||||
|
|||||||
Reference in New Issue
Block a user