refactor Plotter to use a common plot base
This commit is contained in:
@@ -1,19 +1,134 @@
|
||||
from typing import Self
|
||||
"""
|
||||
def plot_actual_output_residual(
|
||||
self,
|
||||
order_residuals: bool = False,
|
||||
row_size: int | float = 2,
|
||||
col_size: int | float = 4,
|
||||
norm_samples: bool = False,
|
||||
combine_dims: bool = True,
|
||||
figure_kwargs: SubplotsKwargs | None = None,
|
||||
subplot_kwargs: dict | None = None,
|
||||
gof_kwargs: dict | None = None,
|
||||
) -> tuple[plt.Figure, AxesArray]:
|
||||
Note: transform samples in dataloader definitions beforehand if you
|
||||
want to change data
|
||||
|
||||
Parameters:
|
||||
row_size:
|
||||
col_size:
|
||||
figure_kwargs:
|
||||
subplot_kwargs:
|
||||
gof_kwargs:
|
||||
|
||||
ndims = self.data_tuples[0][0].size(-1)
|
||||
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
|
||||
fig, axes = self._create_subplots(
|
||||
rows=len(self.dataloaders),
|
||||
cols=1 if (norm_samples or combine_dims) else ndims,
|
||||
row_size=row_size,
|
||||
col_size=col_size,
|
||||
figure_kwargs=figure_kwargs,
|
||||
)
|
||||
subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs)
|
||||
gof_kwargs = self._prepare_gof_kwargs(gof_kwargs)
|
||||
|
||||
for axes_row, data_tuple in zip(axes, self.data_tuples, strict=True):
|
||||
actual, output, loader_label = data_tuple
|
||||
|
||||
if norm_samples:
|
||||
actual = actual.norm(dim=1, keepdim=True)
|
||||
output = output.norm(dim=1, keepdim=True)
|
||||
|
||||
for dim in range(ndims):
|
||||
ax = axes_row[0 if combine_dims else dim]
|
||||
dim_color = colors[dim % len(colors)]
|
||||
group_name = "norm" if norm_samples else f"$d_{dim}$"
|
||||
|
||||
dim_actual = actual[:, dim]
|
||||
dim_output = output[:, dim]
|
||||
residuals = dim_actual - dim_output
|
||||
|
||||
X, Y = dim_actual, residuals
|
||||
if order_residuals:
|
||||
X = range(1, residuals.size(0)+1)
|
||||
Y = residuals[residuals.argsort()]
|
||||
|
||||
ax.scatter(
|
||||
X, Y,
|
||||
color=dim_color,
|
||||
label=group_name,
|
||||
**subplot_kwargs,
|
||||
)
|
||||
|
||||
# plot goodness of fit line
|
||||
m, b = self._lstsq_dim(dim_actual, dim_output)
|
||||
ax.plot(
|
||||
dim_actual,
|
||||
m * dim_actual + b,
|
||||
color=dim_color,
|
||||
label=f"GoF {group_name}",
|
||||
**gof_kwargs,
|
||||
)
|
||||
|
||||
add_plot_context = (
|
||||
norm_samples # dim=0 implicit b/c we break
|
||||
or not combine_dims # add to every subplot across grid
|
||||
or combine_dims and dim == ndims-1 # wait for last dim
|
||||
)
|
||||
|
||||
# always exec plot logic if not combining, o/w just once
|
||||
if add_plot_context:
|
||||
# compare residuals to y=0
|
||||
ax.axhline(y=0, c="black", alpha=0.2)
|
||||
|
||||
ax.set_title(f"[{loader_label}] Prediction residuals")
|
||||
ax.set_xlabel("actual")
|
||||
ax.set_ylabel("residual")
|
||||
|
||||
# transpose legend layout for more natural view
|
||||
if norm_samples or not combine_dims:
|
||||
ax.legend()
|
||||
else:
|
||||
handles, labels = self.get_transposed_handles_labels(ax)
|
||||
ax.legend(handles, labels, ncols=2)
|
||||
|
||||
# break dimension loop if collapsed by norm
|
||||
if norm_samples:
|
||||
break
|
||||
|
||||
return fig, axes
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
from collections.abc import Callable, Generator
|
||||
from collections.abc import Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
from torch import Tensor
|
||||
from numpy.typing import NDArray
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from trainlib.trainer import Trainer
|
||||
from trainlib.estimator import EstimatorKwargs
|
||||
from trainlib.utils.type import AxesArray, SubplotsKwargs
|
||||
|
||||
type SubplotFn = Callable[[plt.Axes, int, Tensor, Tensor], None]
|
||||
type ContextFn = Callable[[plt.Axes, str], None]
|
||||
|
||||
|
||||
class Plotter[B, K: EstimatorKwargs]:
|
||||
"""
|
||||
TODOs:
|
||||
|
||||
- best fit lines for plots and residuals (compare to ideal lines in each
|
||||
case)
|
||||
- show val options across columns; preview how val is changing across
|
||||
natural training, and what the best will look (so plot like uniform
|
||||
intervals broken over the training epochs at 0, 50, 100, 150, ... and
|
||||
highlight the best one, even if that's not actually the single best
|
||||
epoch)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
trainer: Trainer[..., K],
|
||||
@@ -35,11 +150,11 @@ class Plotter[B, K: EstimatorKwargs]:
|
||||
|
||||
self._batch_outputs_fn = partial(
|
||||
self.trainer.get_batch_outputs,
|
||||
batch_estimator_map=self.batch_estimator_map
|
||||
batch_estimator_map=batch_estimator_map
|
||||
)
|
||||
self._batch_metrics_fn = partial(
|
||||
self.trainer.get_batch_metrics,
|
||||
batch_estimator_map=self.batch_estimator_map
|
||||
batch_estimator_map=batch_estimator_map
|
||||
)
|
||||
|
||||
self._data_tuples = None
|
||||
@@ -83,201 +198,369 @@ class Plotter[B, K: EstimatorKwargs]:
|
||||
data_tuples.append((actual, output, label))
|
||||
|
||||
self._data_tuples = data_tuples
|
||||
|
||||
return self._data_tuples
|
||||
|
||||
def _default_figure_kwargs(self, rows: int, cols: int) -> dict:
|
||||
return {
|
||||
"sharex": True,
|
||||
"figsize": (4*cols, 2*rows),
|
||||
"constrained_layout": True,
|
||||
}
|
||||
|
||||
def _default_subplot_kwargs(self) -> dict:
|
||||
return {}
|
||||
|
||||
def _create_subplots(
|
||||
def get_transposed_handles_labels(
|
||||
self,
|
||||
**figure_kwargs: SubplotsKwargs,
|
||||
) -> tuple[plt.Figure, AxesArray]:
|
||||
"""
|
||||
"""
|
||||
ax: plt.Axes
|
||||
) -> tuple[list, list]:
|
||||
# transpose legend layout for more natural view
|
||||
handles, labels = ax.get_legend_handles_labels()
|
||||
handles = handles[::2] + handles[1::2]
|
||||
labels = labels[::2] + labels[1::2]
|
||||
|
||||
rows, cols = len(self.dataloaders), 1
|
||||
figure_kwargs = {
|
||||
**self._default_figure_kwargs(rows, cols),
|
||||
**figure_kwargs,
|
||||
}
|
||||
fig, axes = plt.subplots(
|
||||
rows, cols,
|
||||
squeeze=False,
|
||||
**figure_kwargs
|
||||
return handles, labels
|
||||
|
||||
def _lstsq_dim(
|
||||
self,
|
||||
dim_actual: Tensor,
|
||||
dim_output: Tensor
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
A = torch.stack(
|
||||
[dim_actual, torch.ones_like(dim_actual)],
|
||||
dim=1
|
||||
)
|
||||
m, b = torch.linalg.lstsq(A, dim_output).solution
|
||||
|
||||
return fig, axes
|
||||
return m, b
|
||||
|
||||
def plot_actual_output_dim(
|
||||
def _prepare_figure_kwargs(
|
||||
self,
|
||||
figure_kwargs: dict | None = None,
|
||||
subplot_kwargs: dict | None = None,
|
||||
) -> tuple[plt.Figure, AxesArray]:
|
||||
rows: int,
|
||||
cols: int,
|
||||
row_size: int | float = 2,
|
||||
col_size: int | float = 4,
|
||||
figure_kwargs: SubplotsKwargs | None = None,
|
||||
) -> SubplotsKwargs:
|
||||
"""
|
||||
Wrapper like this works fine, but it's smelly: we *don't* want @wraps,
|
||||
do this method doesn't actually have this signature at runtime (it has
|
||||
the dec wrapper's sig). I think the cleaner thing is to just have
|
||||
internal methods (_func) like the one below, and then the main method
|
||||
entry just pass that internal method through to the skeleton
|
||||
"""
|
||||
|
||||
figure_kwargs = figure_kwargs or {}
|
||||
default_figure_kwargs = {
|
||||
"sharex": True,
|
||||
"figsize": (col_size*cols, row_size*rows),
|
||||
}
|
||||
figure_kwargs = {
|
||||
**default_figure_kwargs,
|
||||
**(figure_kwargs or {}),
|
||||
}
|
||||
|
||||
return figure_kwargs
|
||||
|
||||
def _prepare_subplot_kwargs(
|
||||
self,
|
||||
subplot_kwargs: dict | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
"""
|
||||
|
||||
default_subplot_kwargs = {}
|
||||
subplot_kwargs = {
|
||||
**self._default_subplot_kwargs(),
|
||||
**default_subplot_kwargs,
|
||||
**(subplot_kwargs or {}),
|
||||
}
|
||||
|
||||
fig, axes = self._create_subplots(**figure_kwargs)
|
||||
for ax, data_tuple in zip(axes, self.data_tuples, strict=True):
|
||||
actual, output, label = data_tuple
|
||||
return subplot_kwargs
|
||||
|
||||
def _prepare_gof_kwargs(
|
||||
self,
|
||||
gof_kwargs: dict | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
"""
|
||||
|
||||
default_gof_kwargs = {
|
||||
"alpha": 0.5
|
||||
}
|
||||
gof_kwargs = {
|
||||
**default_gof_kwargs,
|
||||
**(gof_kwargs or {}),
|
||||
}
|
||||
|
||||
return gof_kwargs
|
||||
|
||||
def _create_subplots(
|
||||
self,
|
||||
rows: int,
|
||||
cols: int,
|
||||
row_size: int | float = 2,
|
||||
col_size: int | float = 4,
|
||||
figure_kwargs: SubplotsKwargs | None = None,
|
||||
) -> tuple[plt.Figure, AxesArray]:
|
||||
"""
|
||||
"""
|
||||
|
||||
figure_kwargs: SubplotsKwargs = self._prepare_figure_kwargs(
|
||||
rows,
|
||||
cols,
|
||||
row_size=row_size,
|
||||
col_size=col_size,
|
||||
figure_kwargs=figure_kwargs,
|
||||
)
|
||||
fig, axes = plt.subplots(
|
||||
rows,
|
||||
cols,
|
||||
squeeze=False,
|
||||
**figure_kwargs,
|
||||
) # ty:ignore[no-matching-overload]
|
||||
|
||||
return fig, axes
|
||||
|
||||
def _plot_base(
|
||||
self,
|
||||
subplot_fn: SubplotFn,
|
||||
context_fn: ContextFn,
|
||||
row_size: int | float = 2,
|
||||
col_size: int | float = 4,
|
||||
norm_samples: bool = False,
|
||||
combine_dims: bool = True,
|
||||
figure_kwargs: SubplotsKwargs | None = None,
|
||||
) -> tuple[plt.Figure, AxesArray]:
|
||||
"""
|
||||
Note: transform samples in dataloader definitions beforehand if you
|
||||
want to change data
|
||||
|
||||
Parameters:
|
||||
row_size:
|
||||
col_size:
|
||||
figure_kwargs:
|
||||
subplot_kwargs:
|
||||
gof_kwargs:
|
||||
"""
|
||||
|
||||
ndims = self.data_tuples[0][0].size(-1)
|
||||
fig, axes = self._create_subplots(
|
||||
rows=len(self.dataloaders),
|
||||
cols=1 if (norm_samples or combine_dims) else ndims,
|
||||
row_size=row_size,
|
||||
col_size=col_size,
|
||||
figure_kwargs=figure_kwargs,
|
||||
)
|
||||
|
||||
for axes_row, data_tuple in zip(axes, self.data_tuples, strict=True):
|
||||
actual, output, loader_label = data_tuple
|
||||
|
||||
if norm_samples:
|
||||
actual = actual.norm(dim=1, keepdim=True)
|
||||
output = output.norm(dim=1, keepdim=True)
|
||||
|
||||
for dim in range(ndims):
|
||||
ax = axes_row[0 if combine_dims else dim]
|
||||
|
||||
subplot_fn(ax, dim, actual, output)
|
||||
|
||||
add_plot_context = (
|
||||
norm_samples # dim=0 implicit b/c we break
|
||||
or not combine_dims # add to every subplot across grid
|
||||
or combine_dims and dim == ndims-1 # wait for last dim
|
||||
)
|
||||
|
||||
# always exec plot logic if not combining, o/w exec just once
|
||||
if add_plot_context:
|
||||
context_fn(ax, loader_label)
|
||||
|
||||
# transpose legend layout for more natural view
|
||||
if norm_samples or not combine_dims:
|
||||
ax.legend()
|
||||
else:
|
||||
handles, labels = self.get_transposed_handles_labels(ax)
|
||||
ax.legend(handles, labels, ncols=2)
|
||||
|
||||
# break dimension loop if collapsed by norm
|
||||
if norm_samples:
|
||||
break
|
||||
|
||||
return fig, axes
|
||||
|
||||
def plot_actual_output(
|
||||
self,
|
||||
row_size: int | float = 2,
|
||||
col_size: int | float = 4,
|
||||
norm_samples: bool = False,
|
||||
combine_dims: bool = True,
|
||||
figure_kwargs: SubplotsKwargs | None = None,
|
||||
subplot_kwargs: dict | None = None,
|
||||
gof_kwargs: dict | None = None,
|
||||
) -> tuple[plt.Figure, AxesArray]:
|
||||
"""
|
||||
Plot residual distribution.
|
||||
"""
|
||||
|
||||
subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs)
|
||||
gof_kwargs = self._prepare_gof_kwargs(gof_kwargs)
|
||||
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
|
||||
|
||||
def subplot_fn(
|
||||
ax: plt.Axes, dim: int, actual: Tensor, output: Tensor
|
||||
) -> None:
|
||||
dim_color = colors[dim % len(colors)]
|
||||
group_name = "norm" if norm_samples else f"$d_{dim}$"
|
||||
|
||||
dim_actual = actual[:, dim]
|
||||
dim_output = output[:, dim]
|
||||
|
||||
ax.scatter(
|
||||
dim_actual, dim_output,
|
||||
color=dim_color,
|
||||
label=group_name,
|
||||
**subplot_kwargs,
|
||||
)
|
||||
|
||||
# plot goodness of fit line
|
||||
m, b = self._lstsq_dim(dim_actual, dim_output)
|
||||
ax.plot(
|
||||
dim_actual, m * dim_actual + b,
|
||||
color=dim_color,
|
||||
label=f"GoF {group_name}",
|
||||
**gof_kwargs,
|
||||
)
|
||||
|
||||
def context_fn(ax: plt.Axes, loader_label: str) -> None:
|
||||
# plot perfect prediction reference line, y=x
|
||||
ax.plot(
|
||||
[0, 1], [0, 1],
|
||||
transform=ax.transAxes,
|
||||
c="black",
|
||||
alpha=0.2
|
||||
alpha=0.2,
|
||||
)
|
||||
|
||||
for dim in range(actual.size(-1)):
|
||||
ax.scatter(
|
||||
actual[:, dim],
|
||||
output[:, dim],
|
||||
label=f"$d_{dim}$",
|
||||
**subplot_kwargs
|
||||
)
|
||||
|
||||
ax.set_title(f"[{label}] True labels vs Predictions (dim-wise)")
|
||||
ax.set_title(
|
||||
f"[{loader_label}] True labels vs Predictions"
|
||||
)
|
||||
ax.set_xlabel("actual")
|
||||
ax.set_ylabel("output")
|
||||
ax.legend()
|
||||
|
||||
return fig, axes
|
||||
return self._plot_base(
|
||||
subplot_fn, context_fn,
|
||||
row_size=row_size,
|
||||
col_size=col_size,
|
||||
norm_samples=norm_samples,
|
||||
combine_dims=combine_dims,
|
||||
figure_kwargs=figure_kwargs,
|
||||
)
|
||||
|
||||
def plot_actual_output_residual_dim(
|
||||
def plot_actual_output_residual(
|
||||
self,
|
||||
figure_kwargs: dict | None = None,
|
||||
row_size: int | float = 2,
|
||||
col_size: int | float = 4,
|
||||
order_residuals: bool = False,
|
||||
norm_samples: bool = False,
|
||||
combine_dims: bool = True,
|
||||
figure_kwargs: SubplotsKwargs | None = None,
|
||||
subplot_kwargs: dict | None = None,
|
||||
gof_kwargs: dict | None = None,
|
||||
) -> tuple[plt.Figure, AxesArray]:
|
||||
"""
|
||||
Plot prediction residuals.
|
||||
"""
|
||||
|
||||
figure_kwargs = figure_kwargs or {}
|
||||
subplot_kwargs = {
|
||||
**self._default_subplot_kwargs(),
|
||||
**(subplot_kwargs or {}),
|
||||
}
|
||||
subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs)
|
||||
gof_kwargs = self._prepare_gof_kwargs(gof_kwargs)
|
||||
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
|
||||
|
||||
fig, axes = self._create_subplots(**figure_kwargs)
|
||||
for ax, data_tuple in zip(axes, self.data_tuples, strict=True):
|
||||
actual, output, label = data_tuple
|
||||
def subplot_fn(
|
||||
ax: plt.Axes, dim: int, actual: Tensor, output: Tensor
|
||||
) -> None:
|
||||
dim_color = colors[dim % len(colors)]
|
||||
group_name = "norm" if norm_samples else f"$d_{dim}$"
|
||||
|
||||
dim_actual = actual[:, dim]
|
||||
dim_output = output[:, dim]
|
||||
|
||||
residuals = dim_actual - dim_output
|
||||
X, Y = dim_actual, residuals
|
||||
|
||||
if order_residuals:
|
||||
X = range(1, residuals.size(0)+1)
|
||||
Y = residuals[residuals.argsort()]
|
||||
|
||||
ax.scatter(
|
||||
X, Y,
|
||||
color=dim_color,
|
||||
label=group_name,
|
||||
**subplot_kwargs,
|
||||
)
|
||||
|
||||
# plot goodness of fit line
|
||||
if not order_residuals:
|
||||
m, b = self._lstsq_dim(dim_actual, residuals)
|
||||
ax.plot(
|
||||
dim_actual, m * dim_actual + b,
|
||||
color=dim_color,
|
||||
label=f"GoF {group_name}",
|
||||
**gof_kwargs,
|
||||
)
|
||||
|
||||
def context_fn(ax: plt.Axes, loader_label: str) -> None:
|
||||
# compare residuals to y=0
|
||||
ax.axhline(y=0, c="black", alpha=0.2)
|
||||
|
||||
for dim in range(actual.size(-1)):
|
||||
ax.scatter(
|
||||
actual[:, dim],
|
||||
actual[:, dim] - output[:, dim],
|
||||
label=f"$d_{dim}$",
|
||||
**subplot_kwargs
|
||||
)
|
||||
|
||||
ax.set_title(f"[{label}] Residuals (dim-wise)")
|
||||
ax.set_title(f"[{loader_label}] Prediction residuals")
|
||||
ax.set_xlabel("actual")
|
||||
ax.set_ylabel("residual")
|
||||
ax.legend()
|
||||
|
||||
return fig, axes
|
||||
|
||||
def plot_actual_output_ordered_residual_dim(
|
||||
self,
|
||||
figure_kwargs: dict | None = None,
|
||||
subplot_kwargs: dict | None = None,
|
||||
) -> tuple[plt.Figure, AxesArray]:
|
||||
"""
|
||||
"""
|
||||
|
||||
figure_kwargs = figure_kwargs or {}
|
||||
subplot_kwargs = {
|
||||
**self._default_subplot_kwargs(),
|
||||
**(subplot_kwargs or {}),
|
||||
}
|
||||
|
||||
fig, axes = self._create_subplots(**figure_kwargs)
|
||||
for ax, data_tuple in zip(axes, self.data_tuples, strict=True):
|
||||
actual, output, label = data_tuple
|
||||
|
||||
# compare residuals to y=0
|
||||
ax.axhline(y=0, c="black", alpha=0.2)
|
||||
|
||||
for dim in range(actual.size(-1)):
|
||||
ax.scatter(
|
||||
actual[:, dim],
|
||||
actual[:, dim] - output[:, dim],
|
||||
label=f"$d_{dim}$",
|
||||
**subplot_kwargs
|
||||
)
|
||||
|
||||
ax.set_title(f"[{label}] Residuals (dim-wise)")
|
||||
ax.set_xlabel("actual")
|
||||
ax.set_ylabel("residual")
|
||||
ax.legend()
|
||||
|
||||
return fig, axes
|
||||
|
||||
return self._plot_base(
|
||||
subplot_fn, context_fn,
|
||||
row_size=row_size,
|
||||
col_size=col_size,
|
||||
norm_samples=norm_samples,
|
||||
combine_dims=combine_dims,
|
||||
figure_kwargs=figure_kwargs,
|
||||
)
|
||||
|
||||
def plot_actual_output_residual_dist(
|
||||
self,
|
||||
figure_kwargs: dict | None = None,
|
||||
row_size: int | float = 2,
|
||||
col_size: int | float = 4,
|
||||
norm_samples: bool = False,
|
||||
combine_dims: bool = True,
|
||||
figure_kwargs: SubplotsKwargs | None = None,
|
||||
subplot_kwargs: dict | None = None,
|
||||
gof_kwargs: dict | None = None,
|
||||
) -> tuple[plt.Figure, AxesArray]:
|
||||
"""
|
||||
Plot residual distribution.
|
||||
"""
|
||||
|
||||
figure_kwargs = figure_kwargs or {}
|
||||
subplot_kwargs = {
|
||||
**self._default_subplot_kwargs(),
|
||||
**(subplot_kwargs or {}),
|
||||
}
|
||||
subplot_kwargs = self._prepare_subplot_kwargs(subplot_kwargs)
|
||||
gof_kwargs = self._prepare_gof_kwargs(gof_kwargs)
|
||||
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
|
||||
|
||||
fig, axes = self._create_subplots(**figure_kwargs)
|
||||
for ax, data_tuple in zip(axes, self.data_tuples, strict=True):
|
||||
actual, output, label = data_tuple
|
||||
def subplot_fn(
|
||||
ax: plt.Axes, dim: int, actual: Tensor, output: Tensor
|
||||
) -> None:
|
||||
dim_color = colors[dim % len(colors)]
|
||||
group_name = "norm" if norm_samples else f"$d_{dim}$"
|
||||
|
||||
N = actual.size(0)
|
||||
for dim in range(actual.size(-1)):
|
||||
residuals = actual[:, dim] - output[:, dim]
|
||||
dim_actual = actual[:, dim]
|
||||
dim_output = output[:, dim]
|
||||
|
||||
_, _, patches = ax.hist(
|
||||
residuals.abs(),
|
||||
bins=int(np.sqrt(N)),
|
||||
density=True,
|
||||
alpha=0.2,
|
||||
label=f"$d_{dim}$",
|
||||
**subplot_kwargs
|
||||
)
|
||||
N = dim_actual.size(0)
|
||||
residuals = dim_actual - dim_output
|
||||
|
||||
# grab color used for hist and mirror in the v-line
|
||||
color = patches[0].get_facecolor()
|
||||
_, _, patches = ax.hist(
|
||||
residuals.abs(),
|
||||
bins=int(np.sqrt(N)),
|
||||
density=True,
|
||||
alpha=0.3,
|
||||
color=dim_color,
|
||||
label=group_name,
|
||||
**subplot_kwargs
|
||||
)
|
||||
|
||||
mu = residuals.abs().mean().item()
|
||||
ax.axvline(mu, linestyle=":", c=color, label=f"$\mu_{dim}$")
|
||||
mu = residuals.abs().mean().item()
|
||||
ax.axvline(mu, linestyle=":", c=dim_color, label=f"$\\mu_{dim}$")
|
||||
|
||||
ax.set_title(f"[{label}] Residual distribution (dim-wise)")
|
||||
def context_fn(ax: plt.Axes, loader_label: str) -> None:
|
||||
ax.set_title(f"[{loader_label}] Residual distribution")
|
||||
ax.set_xlabel("actual")
|
||||
ax.set_ylabel("residual")
|
||||
ax.set_ylabel("residual (density)")
|
||||
|
||||
# transpose legend layout for more natural view
|
||||
handles, labels = ax.get_legend_handles_labels()
|
||||
handles = handles[::2] + handles[1::2]
|
||||
labels = labels[::2] + labels[1::2]
|
||||
|
||||
ax.legend(handles, labels, ncols=2)
|
||||
|
||||
return fig, axes
|
||||
return self._plot_base(
|
||||
subplot_fn, context_fn,
|
||||
row_size=row_size,
|
||||
col_size=col_size,
|
||||
norm_samples=norm_samples,
|
||||
combine_dims=combine_dims,
|
||||
figure_kwargs=figure_kwargs,
|
||||
)
|
||||
|
||||
@@ -1,46 +1,49 @@
|
||||
text.usetex : False
|
||||
mathtext.default : regular
|
||||
text.usetex : False
|
||||
mathtext.default : regular
|
||||
|
||||
font.family : sans-serif
|
||||
font.sans-serif : DejaVu Sans
|
||||
font.serif : DejaVu Serif
|
||||
font.cursive : DejaVu Sans
|
||||
mathtext.fontset : dejavuserif
|
||||
font.size : 9
|
||||
figure.titlesize : 9
|
||||
legend.fontsize : 9
|
||||
axes.titlesize : 9
|
||||
axes.labelsize : 9
|
||||
xtick.labelsize : 9
|
||||
ytick.labelsize : 9
|
||||
# testing to prevent component overlap/clipping
|
||||
figure.constrained_layout.use : True
|
||||
|
||||
#axes.prop_cycle : cycler('color', ['4f7dd5', 'af7031', '55905e', 'd84739', '888348', 'b75e8b', '2f8f99', '9862cb'])
|
||||
axes.prop_cycle : cycler('color', ['5e8de4', 'c38141', '67a771', 'e15344', '9e9858', '41a6b0', 'a46fd7', 'c86d9a'])
|
||||
font.family : sans-serif
|
||||
font.sans-serif : DejaVu Sans
|
||||
font.serif : DejaVu Serif
|
||||
font.cursive : DejaVu Sans
|
||||
mathtext.fontset : dejavuserif
|
||||
font.size : 9
|
||||
figure.titlesize : 9
|
||||
legend.fontsize : 9
|
||||
axes.titlesize : 9
|
||||
axes.labelsize : 9
|
||||
xtick.labelsize : 9
|
||||
ytick.labelsize : 9
|
||||
|
||||
image.interpolation : nearest
|
||||
image.resample : False
|
||||
image.composite_image : True
|
||||
# monobiome -d 0.45 -l 22
|
||||
axes.prop_cycle : cycler('color', ['5e8de4', 'c38141', '67a771', 'e15344', '9e9858', '41a6b0', 'a46fd7', 'c86d9a'])
|
||||
|
||||
axes.spines.left : True
|
||||
axes.spines.bottom : True
|
||||
axes.spines.top : False
|
||||
axes.spines.right : False
|
||||
image.interpolation : nearest
|
||||
image.resample : False
|
||||
image.composite_image : True
|
||||
|
||||
axes.linewidth : 1
|
||||
xtick.major.width : 1
|
||||
xtick.minor.width : 1
|
||||
ytick.major.width : 1
|
||||
ytick.minor.width : 1
|
||||
axes.spines.left : True
|
||||
axes.spines.bottom : True
|
||||
axes.spines.top : False
|
||||
axes.spines.right : False
|
||||
|
||||
lines.linewidth : 1
|
||||
lines.markersize : 1
|
||||
axes.linewidth : 1
|
||||
xtick.major.width : 1
|
||||
xtick.minor.width : 1
|
||||
ytick.major.width : 1
|
||||
ytick.minor.width : 1
|
||||
|
||||
savefig.dpi : 300
|
||||
savefig.format : svg
|
||||
savefig.bbox : tight
|
||||
savefig.pad_inches : 0.1
|
||||
lines.linewidth : 1
|
||||
lines.markersize : 1
|
||||
|
||||
svg.image_inline : True
|
||||
svg.fonttype : none
|
||||
savefig.dpi : 300
|
||||
savefig.format : svg
|
||||
savefig.bbox : tight
|
||||
savefig.pad_inches : 0.1
|
||||
|
||||
legend.frameon : False
|
||||
svg.image_inline : True
|
||||
svg.fonttype : none
|
||||
|
||||
legend.frameon : False
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import Any, TypedDict
|
||||
from collections.abc import Callable, Iterable
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
|
||||
import numpy as np
|
||||
from torch import Tensor
|
||||
@@ -59,18 +59,11 @@ class SubplotsKwargs(TypedDict, total=False):
|
||||
sharex: bool | str
|
||||
sharey: bool | str
|
||||
squeeze: bool
|
||||
width_ratios: list[float]
|
||||
height_ratios: list[float]
|
||||
subplot_kw: dict
|
||||
gridspec_kw: dict
|
||||
width_ratios: Sequence[float]
|
||||
height_ratios: Sequence[float]
|
||||
subplot_kw: dict[str, ...]
|
||||
gridspec_kw: dict[str, ...]
|
||||
figsize: tuple[float, float]
|
||||
dpi: float
|
||||
layout: str
|
||||
|
||||
sharex: bool | Literal["none", "all", "row", "col"] = False,
|
||||
sharey: bool | Literal["none", "all", "row", "col"] = False,
|
||||
squeeze: bool = True,
|
||||
width_ratios: Sequence[float] | None = None,
|
||||
height_ratios: Sequence[float] | None = None,
|
||||
subplot_kw: dict[str, Any] | None = None,
|
||||
gridspec_kw: dict[str, Any] | None = None,
|
||||
constrained_layout: bool
|
||||
|
||||
Reference in New Issue
Block a user