refactor Trainer (large), improve dataset/est middleware

This commit is contained in:
2026-03-22 00:11:58 -07:00
parent 85d176862e
commit a395a08d5c
6 changed files with 300 additions and 288 deletions

102
trainlib/dataloader.py Normal file
View File

@@ -0,0 +1,102 @@
"""
This class took me a long time to really settle into. It's a connector, and it
feels redundant in many ways, so I've nearly deleted it several times while
talking through the design. But in total, I think it serves a clear purpose.
Reasons:
- Need a typed dataloader, even if I know the type of my attached transform
- Need a new scope that uses the same base dataset without interfering with the
transform attribute; a design that sets or relies on that is subject to
conflict
- Why not just use vanilla DataLoaders?
I'd like to, but the two reasons above make it clear why this is challenging:
I don't get static checks on the structures returned during iteration, and
while you can control ad hoc data transforms via dataset ``post_transforms``,
things can get messy if you need to do that for many transforms using the
same dataset (without copying). Simplest way around this is just a new scope
with the same underlying dataset instance and a transform wrapper around the
iterator; no interference with object attributes.
This is really just meant as the minimum viable logic needed to accomplish the
above - it's a very lightweight wrapper on the base ``DataLoader`` object.
There's an explicit type upper bound ``Kw: EstimatorKwargs``, but it is
otherwise a completely general transform over dataloader batches, highlighting
that it's *mostly* here to play nicely with type checks.
"""
from typing import Unpack
from collections.abc import Iterator
from torch.utils.data import DataLoader
from trainlib.dataset import BatchedDataset
from trainlib.estimator import EstimatorKwargs
from trainlib.utils.type import LoaderKwargs
class EstimatorDataLoader[B, Kw: EstimatorKwargs]:
"""
Data loaders for estimators.
This class exists to connect batched data from datasets to the expected
representation for estimator methods. Datasets may be developed
independently from a given model structures, and models should be trainable
under any such data. We need a way to ensure the batched groups of items we
get from dataloaders match on a type level, i.e., can be reshaped into the
expected ``Kw`` signature.
Note: batch structure ``B`` cannot be directly inferred from type variables
exposed by ``BatchedDatasets`` (namely ``R`` and ``I``). What's returned by
a data loader wrapping any such dataset can be arbitrary (depending on the
``collate_fn``), with default behavior being fairly consistent under nested
collections but challenging to accurately type.
.. todo::
To log (have changed for Trainer):
- New compact eval pipeline for train/val/auxiliary dataloaders.
Somewhat annoying logic, but handled consistently
- Convergence tracker will dynamically use training loss (early
stopping) when a validation set isn't provided. Same mechanics for
stagnant epochs (although early stopping is generally a little more
nuanced, having a rate-based stopper, b/c train loss generally quite
monotonic). So that's to be updated, plus room for possible model
selection strategies later.
- Logging happens at each batch, but we append to an epoch-indexed list
and later average. There was a bug in the last round of testing that
I didn't pick up where I was just overwriting summaries using the
last seen batch.
- Reworked general dataset/dataloader handling for main train loop, now
accepting objects of this class to bridge estimator and dataset
communication. This cleans up the batch mapping model.
- TODO: implement a version of this that canonically works with the
device passing plus EstimatorKwargs input; this is the last fuzzy bit
I think.
"""
def __init__(
self,
dataset: BatchedDataset,
**dataloader_kwargs: Unpack[LoaderKwargs],
) -> None:
self._dataloader = DataLoader(dataset, **dataloader_kwargs)
def batch_to_est_kwargs(self, batch_data: B) -> Kw:
"""
.. note::
Even if we have a concrete shape for the output kwarg dict for base
estimators (requiring a tensor "inputs" attribute), we don't
presuppose how a given batch object will map into this dict
structure.
return EstimatorKwargs({"inputs":0})
"""
raise NotImplementedError
def __iter__(self) -> Iterator[Kw]:
return map(self.batch_to_est_kwargs, self._dataloader)

View File

View File

@@ -1,21 +1,21 @@
from functools import partial
from typing import Any
from collections.abc import Callable
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch import Tensor
from torch.utils.data import DataLoader
from trainlib.trainer import Trainer
from trainlib.estimator import EstimatorKwargs
from trainlib.dataloader import EstimatorDataLoader
from trainlib.utils.type import AxesArray, SubplotsKwargs
type SubplotFn = Callable[[plt.Axes, int, Tensor, Tensor], None]
type ContextFn = Callable[[plt.Axes, str], None]
class Plotter[B, K: EstimatorKwargs]:
class Plotter[Kw: EstimatorKwargs]:
"""
TODOs:
@@ -30,10 +30,9 @@ class Plotter[B, K: EstimatorKwargs]:
def __init__(
self,
trainer: Trainer[..., K],
dataloaders: list[DataLoader],
batch_estimator_map: Callable[[B, Trainer], ...],
estimator_to_output_map: Callable[[K], ...],
trainer: Trainer[Any, Kw],
dataloaders: list[EstimatorDataLoader[Any, Kw]],
kw_to_actual: Callable[[Kw], Tensor],
dataloader_labels: list[str] | None = None,
) -> None:
self.trainer = trainer
@@ -41,47 +40,21 @@ class Plotter[B, K: EstimatorKwargs]:
self.dataloader_labels = (
dataloader_labels or list(map(str, range(1, len(dataloaders)+1)))
)
self.batch_estimator_map = batch_estimator_map
self.estimator_to_output_map = estimator_to_output_map
self.kw_to_actual = kw_to_actual
self._outputs: list[list[Tensor]] | None = None
self._metrics: list[list[dict[str, float]]] | None = None
self._batch_outputs_fn = partial(
self.trainer.get_batch_outputs,
batch_estimator_map=batch_estimator_map
)
self._batch_metrics_fn = partial(
self.trainer.get_batch_metrics,
batch_estimator_map=batch_estimator_map
)
self._data_tuples = None
@property
def outputs(self) -> list[list[Tensor]]:
if self._outputs is None:
self._outputs = [
list(map(self._batch_outputs_fn, loader))
for loader in self.dataloaders
]
return self._outputs
@property
def metrics(self) -> list[list[dict[str, float]]]:
if self._metrics is None:
self._metrics = [
list(map(self._batch_metrics_fn, loader))
for loader in self.dataloaders
]
return self._metrics
@property
def data_tuples(self) -> list[tuple[Tensor, Tensor, str]]:
"""
Produce data items; to be cached. Zip later with axes
"""
self.trainer.estimator.eval()
if self._data_tuples is not None:
return self._data_tuples
@@ -89,10 +62,14 @@ class Plotter[B, K: EstimatorKwargs]:
for i, loader in enumerate(self.dataloaders):
label = self.dataloader_labels[i]
batch = next(iter(loader))
est_kwargs = self.batch_estimator_map(batch, self.trainer)
actual = self.estimator_to_output_map(est_kwargs)
output = self._batch_outputs_fn(batch)
actual = torch.cat([
self.kw_to_actual(batch_kwargs).detach().cpu()
for batch_kwargs in loader
])
output = torch.cat([
self.trainer.estimator(**batch_kwargs)[0].detach().cpu()
for batch_kwargs in loader
])
data_tuples.append((actual, output, label))
@@ -219,6 +196,14 @@ class Plotter[B, K: EstimatorKwargs]:
Note: transform samples in dataloader definitions beforehand if you
want to change data
.. todo::
Merge in logic from general diagnostics, allowing collapse from
either dim and transposing.
Later: multi-trial error bars, or at least the ability to pass that
downstream
Parameters:
row_size:
col_size:
@@ -462,3 +447,72 @@ class Plotter[B, K: EstimatorKwargs]:
combine_dims=combine_dims,
figure_kwargs=figure_kwargs,
)
def estimator_diagnostic(
    self,
    row_size: int | float = 2,
    col_size: int | float = 4,
    session_name: str | None = None,
    combine_groups: bool = False,
    combine_metrics: bool = False,
    transpose_layout: bool = False,
    figure_kwargs: SubplotsKwargs | None = None,
):
    """
    Plot logged per-epoch metrics for a training session.

    Reads the trainer's event log (session -> group -> metric -> epoch ->
    values) and draws one line per (group, metric) series, averaging the
    values recorded within each epoch.

    Parameters:
        row_size: height allotted to each subplot row
        col_size: width allotted to each subplot column
        session_name: session to plot; defaults to the first logged session
        combine_groups: draw all groups onto a single row of axes
        combine_metrics: draw all metrics onto a single column of axes
        transpose_layout: swap the row/column roles of groups and metrics
        figure_kwargs: extra kwargs forwarded to the subplot factory
    """
    session_map = self.trainer._event_log
    session_name = session_name or next(iter(session_map))
    groups = session_map[session_name]
    # Size the grid from the first group; assumes every group logs the
    # same set of metrics. TODO(review): confirm this invariant holds.
    num_metrics = len(groups[next(iter(groups))])
    rows = 1 if combine_groups else len(groups)
    cols = 1 if combine_metrics else num_metrics
    if transpose_layout:
        rows, cols = cols, rows
    _, axes = self._create_subplots(
        rows=rows,
        cols=cols,
        row_size=row_size,
        col_size=col_size,
        figure_kwargs=figure_kwargs,
    )
    if transpose_layout:
        # Transpose back so the indexing below can stay group-major.
        axes = axes.T
    for i, group_name in enumerate(groups):
        axes_row = axes[0 if combine_groups else i]
        group_metrics = groups[group_name]
        for j, metric_name in enumerate(group_metrics):
            ax = axes_row[0 if combine_metrics else j]
            metric_dict = group_metrics[metric_name]
            # Average the per-batch values logged within each epoch.
            metric_data = np.array([
                (k, np.mean(v)) for k, v in metric_dict.items()
            ])
            if combine_groups and combine_metrics:
                label = f"{group_name}-{metric_name}"
                title_prefix = "all"
            elif combine_groups:
                label = group_name
                title_prefix = metric_name
            else:
                # Covers both the plain and combine_metrics-only layouts.
                label = metric_name
                title_prefix = group_name
            ax.plot(
                metric_data[:, 0],
                metric_data[:, 1],
                label=label,
            )
            ax.set_title(f"[{title_prefix}] Metrics over epochs")
            ax.set_xlabel("epoch", fontstyle='italic')
            ax.set_ylabel("value", fontstyle='italic')
            ax.legend()

View File

@@ -1,5 +1,15 @@
"""
Core interface for training ``Estimators`` with ``Datasets``
.. admonition:: Design of previous ``get_dataloaders()``
Note how much that method was doing, and the benefit of making those steps
more explicit elsewhere. The assignment of transforms to datasets before
wrapping as loaders is chief among these items, alongside the balancing and
splitting; I think those were hamfisted in to make it work with the old
setup, and I generally think it's not consistent with the name "get
dataloaders" (i.e., it would also balance, split, and set transforms)
"""
import os
@@ -7,48 +17,41 @@ import time
import logging
from io import BytesIO
from copy import deepcopy
from typing import Any, Self
from typing import Any
from pathlib import Path
from collections import defaultdict
from collections.abc import Callable
import torch
from tqdm import tqdm
from torch import cuda, Tensor
from torch.optim import Optimizer
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from trainlib.dataset import BatchedDataset
from trainlib.estimator import Estimator, EstimatorKwargs
from trainlib.transform import Transform
from trainlib.utils.type import (
SplitKwargs,
LoaderKwargs,
BalanceKwargs,
)
from trainlib.utils.map import nested_defaultdict
from trainlib.dataloader import EstimatorDataLoader
from trainlib.utils.module import ModelWrapper
logger: logging.Logger = logging.getLogger(__name__)
class Trainer[I, K: EstimatorKwargs]:
class Trainer[I, Kw: EstimatorKwargs]:
"""
Training interface for optimizing parameters of ``Estimators`` with
``Datasets``.
This class is generic to a dataset item type ``I`` and an estimator kwarg
type ``K``. These are the two primary components ``Trainer`` objects need
type ``Kw``. These are the two primary components ``Trainer`` objects need
to coordinate: they ultimately rely on a provided map to ensure data items
(type ``I``) from a dataset are appropriately routed as inputs to key
estimator methods (like ``forward()`` and ``loss()``), which accept inputs
of type ``K``.
of type ``Kw``.
"""
def __init__(
self,
estimator: Estimator[K],
estimator: Estimator[Kw],
device: str | None = None,
chkpt_dir: str = "chkpt/",
tblog_dir: str = "tblog/",
@@ -93,6 +96,7 @@ class Trainer[I, K: EstimatorKwargs]:
self.estimator = estimator
self.estimator.to(self.device)
self._event_log = nested_defaultdict(4, list)
self.chkpt_dir = Path(chkpt_dir).resolve()
self.tblog_dir = Path(tblog_dir).resolve()
@@ -105,56 +109,28 @@ class Trainer[I, K: EstimatorKwargs]:
"""
self._epoch: int = 1
self._summary = defaultdict(lambda: defaultdict(dict))
self._event_log = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
self._summary = defaultdict(lambda: defaultdict(list))
self._val_loss = float("inf")
self._best_val_loss = float("inf")
self._conv_loss = float("inf")
self._best_conv_loss = float("inf")
self._stagnant_epochs = 0
self._best_model_state_dict: dict[str, Any] = {}
def _train_epoch(
self,
train_loader: DataLoader,
batch_estimator_map: Callable[[I, Self], K],
loader: EstimatorDataLoader[Any, Kw],
optimizers: tuple[Optimizer, ...],
max_grad_norm: float | None = None,
) -> list[float]:
"""
Train the estimator for a single epoch.
.. admonition:: On summary writers
Estimators can have several optimizers, and therefore can emit
several losses. This is a fairly unique case, but it's needed when
we want to optimize particular parameters in a particular order
(as in multi-model architectures, e.g., GANs). Point being: we
always iterate over optimizers/losses, even in the common case
where there's just a single value, and we index collections across
batches accordingly.
A few of the trackers, with the same size as the number of
optimizers:
- ``train_loss_sums``: tracks loss sums across all batches for the
epoch, used to update the loop preview text after each batch with
the current average loss
- ``train_loss_items``: collects current batch losses, recorded by
the TB writer
If there are ``M`` optimizers/losses, we log ``M`` loss terms to
the TB writer after each *batch* (not epoch). We could aggregate at
an epoch level, but parameter updates take place after each batch,
so large model changes can occur over the course of an epoch
(whereas the model remains the same over the course batch evals).
"""
loss_sums = []
self.estimator.train()
with tqdm(train_loader, unit="batch") as batches:
for i, batch_data in enumerate(batches):
est_kwargs = batch_estimator_map(batch_data, self)
losses = self.estimator.loss(**est_kwargs)
with tqdm(loader, unit="batch") as batches:
for i, batch_kwargs in enumerate(batches):
losses = self.estimator.loss(**batch_kwargs)
for o_idx, (loss, optimizer) in enumerate(
zip(losses, optimizers, strict=True)
@@ -186,9 +162,8 @@ class Trainer[I, K: EstimatorKwargs]:
def _eval_epoch(
self,
loader: DataLoader,
batch_estimator_map: Callable[[I, Self], K],
loader_label: str,
loader: EstimatorDataLoader[Any, Kw],
label: str
) -> list[float]:
"""
Perform and record validation scores for a single epoch.
@@ -217,22 +192,21 @@ class Trainer[I, K: EstimatorKwargs]:
loss_sums = []
self.estimator.eval()
with tqdm(loader, unit="batch") as batches:
for i, batch_data in enumerate(batches):
est_kwargs = batch_estimator_map(batch_data, self)
losses = self.estimator.loss(**est_kwargs)
for i, batch_kwargs in enumerate(batches):
losses = self.estimator.loss(**batch_kwargs)
# one-time logging
if self._epoch == 0:
self._writer.add_graph(
ModelWrapper(self.estimator), est_kwargs
ModelWrapper(self.estimator), batch_kwargs
)
# once-per-epoch logging
if i == 0:
self.estimator.epoch_write(
self._writer,
step=self._epoch,
group=loader_label,
**est_kwargs
group=label,
**batch_kwargs
)
loss_items = []
@@ -250,21 +224,21 @@ class Trainer[I, K: EstimatorKwargs]:
# log individual loss terms after each batch
for o_idx, loss_item in enumerate(loss_items):
self._log_event(loader_label, f"loss_{o_idx}", loss_item)
self._log_event(label, f"loss_{o_idx}", loss_item)
# log metrics for batch
estimator_metrics = self.estimator.metrics(**est_kwargs)
estimator_metrics = self.estimator.metrics(**batch_kwargs)
for metric_name, metric_value in estimator_metrics.items():
self._log_event(loader_label, metric_name, metric_value)
self._log_event(label, metric_name, metric_value)
return loss_sums
def _eval_loaders(
self,
loaders: list[DataLoader],
batch_estimator_map: Callable[[I, Self], K],
loader_labels: list[str],
) -> dict[str, list[float]]:
train_loader: EstimatorDataLoader[Any, Kw],
val_loader: EstimatorDataLoader[Any, Kw] | None = None,
aux_loaders: list[EstimatorDataLoader[Any, Kw]] | None = None,
) -> tuple[list[float], list[float] | None, *list[float]]:
"""
Evaluate estimator over each provided dataloader.
@@ -280,40 +254,50 @@ class Trainer[I, K: EstimatorKwargs]:
information (just aggregated losses are provided here).
"""
return {
label: self._eval_epoch(loader, batch_estimator_map, label)
for loader, label in zip(loaders, loader_labels, strict=True)
}
train_loss = self._eval_epoch(train_loader, "train")
val_loss = self._eval_epoch(val_loader, "val") if val_loader else None
def train[B](
aux_loaders = aux_loaders or []
aux_losses = [
self._eval_epoch(aux_loader, f"aux{i}")
for i, aux_loader in enumerate(aux_loaders)
]
return train_loss, val_loss, *aux_losses
def train(
self,
dataset: BatchedDataset[..., ..., I],
batch_estimator_map: Callable[[B, Self], K],
train_loader: EstimatorDataLoader[Any, Kw],
val_loader: EstimatorDataLoader[Any, Kw] | None = None,
aux_loaders: list[EstimatorDataLoader[Any, Kw]] | None = None,
*,
lr: float = 1e-3,
eps: float = 1e-8,
max_grad_norm: float | None = None,
max_epochs: int = 10,
stop_after_epochs: int = 5,
batch_size: int = 256,
val_frac: float = 0.1,
train_transform: Transform | None = None,
val_transform: Transform | None = None,
dataset_split_kwargs: SplitKwargs | None = None,
dataset_balance_kwargs: BalanceKwargs | None = None,
dataloader_kwargs: LoaderKwargs | None = None,
summarize_every: int = 1,
chkpt_every: int = 1,
resume_latest: bool = False,
session_name: str | None = None,
summary_writer: SummaryWriter | None = None,
aux_loaders: list[DataLoader] | None = None,
aux_loader_labels: list[str] | None = None,
) -> Estimator:
"""
TODO: consider making the dataloader ``collate_fn`` an explicit
parameter with a type signature that reflects ``B``, connecting the
``batch_estimator_map`` somewhere. Might also re-type a ``DataLoader``
in-house to allow a generic around ``B``
.. todo::
- consider making the dataloader ``collate_fn`` an explicit
parameter with a type signature that reflects ``B``, connecting
the ``batch_estimator_map`` somewhere. Might also re-type a
``DataLoader`` in-house to allow a generic around ``B``
- Rework the validation specification. Accept something like a
"validate_with" parameter, or perhaps just move entirely to
accepting a dataloader list, label list. You might then also need
a "train_with," and you could set up sensible defaults so you
basically have the same interaction as now. The "problem" is you
always need a train set, and there's some clearly dependent logic
on a val set, but you don't *need* val, so this should be
slightly reworked (and the more general, *probably* the better in
this case, given I want to plug into the Plotter with possibly
several purely eval sets over the model training lifetime).
Note: this method attempts to implement a general scheme for passing
needed items to the estimator's loss function from the dataloader. The
@@ -355,7 +339,7 @@ class Trainer[I, K: EstimatorKwargs]:
This function should map from batches - which *may* be item
shaped, i.e., have an ``I`` skeleton, even if stacked items may be
different on the inside - into estimator keyword arguments (type
``K``). Collation behavior from a DataLoader (which can be
``Kw``). Collation behavior from a DataLoader (which can be
customized) doesn't consistently yield a known type shape, however,
so it's not appropriate to use ``I`` as the callable param type.
@@ -416,18 +400,13 @@ class Trainer[I, K: EstimatorKwargs]:
dataset
val_split_frac: fraction of dataset to use for validation
chkpt_every: how often model checkpoints should be saved
resume_latest: resume training from the latest available checkpoint
in the `chkpt_dir`
"""
logger.info("> Begin train loop:")
logger.info(f"| > {lr=}")
logger.info(f"| > {eps=}")
logger.info(f"| > {max_epochs=}")
logger.info(f"| > {batch_size=}")
logger.info(f"| > {val_frac=}")
logger.info(f"| > {chkpt_every=}")
logger.info(f"| > {resume_latest=}")
logger.info(f"| > with device: {self.device}")
logger.info(f"| > core count: {os.cpu_count()}")
@@ -435,25 +414,9 @@ class Trainer[I, K: EstimatorKwargs]:
tblog_path = Path(self.tblog_dir, self._session_name)
self._writer = summary_writer or SummaryWriter(f"{tblog_path}")
aux_loaders = aux_loaders or []
aux_loader_labels = aux_loader_labels or []
optimizers = self.estimator.optimizers(lr=lr, eps=eps)
train_loader, val_loader = self.get_dataloaders(
dataset,
batch_size,
val_frac=val_frac,
train_transform=train_transform,
val_transform=val_transform,
dataset_split_kwargs=dataset_split_kwargs,
dataset_balance_kwargs=dataset_balance_kwargs,
dataloader_kwargs=dataloader_kwargs,
)
loaders = [train_loader, val_loader, *aux_loaders]
loader_labels = ["train", "val", *aux_loader_labels]
# evaluate model on dataloaders once before training starts
self._eval_loaders(loaders, batch_estimator_map, loader_labels)
self._eval_loaders(train_loader, val_loader, aux_loaders)
optimizers = self.estimator.optimizers(lr=lr, eps=eps)
while self._epoch <= max_epochs and not self._converged(
self._epoch, stop_after_epochs
@@ -464,22 +427,14 @@ class Trainer[I, K: EstimatorKwargs]:
print(f"Stagnant epochs {stag_frac}...")
epoch_start_time = time.time()
self._train_epoch(
train_loader,
batch_estimator_map,
optimizers,
max_grad_norm,
)
self._train_epoch(train_loader, optimizers, max_grad_norm)
epoch_end_time = time.time() - epoch_start_time
self._log_event("train", "epoch_duration", epoch_end_time)
loss_sum_map = self._eval_loaders(
loaders,
batch_estimator_map,
loader_labels,
train_loss, val_loss, _ = self._eval_loaders(
train_loader, val_loader, aux_loaders
)
val_loss_sums = loss_sum_map["val"]
self._val_loss = sum(val_loss_sums) / len(val_loader)
self._conv_loss = sum(val_loss) if val_loss else sum(train_loss)
if self._epoch % summarize_every == 0:
self._summarize()
@@ -492,8 +447,8 @@ class Trainer[I, K: EstimatorKwargs]:
def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
converged = False
if epoch == 1 or self._val_loss < self._best_val_loss:
self._best_val_loss = self._val_loss
if epoch == 1 or self._conv_loss < self._best_val_loss:
self._best_val_loss = self._conv_loss
self._stagnant_epochs = 0
self._best_model_state_dict = deepcopy(self.estimator.state_dict())
else:
@@ -505,110 +460,24 @@ class Trainer[I, K: EstimatorKwargs]:
return converged
@staticmethod
def get_dataloaders(
dataset: BatchedDataset,
batch_size: int,
val_frac: float = 0.1,
train_transform: Transform | None = None,
val_transform: Transform | None = None,
dataset_split_kwargs: SplitKwargs | None = None,
dataset_balance_kwargs: BalanceKwargs | None = None,
dataloader_kwargs: LoaderKwargs | None = None,
) -> tuple[DataLoader, DataLoader]:
"""
Create training and validation dataloaders for the provided dataset.
.. todo::
Decide on policy for empty val dataloaders
"""
if dataset_split_kwargs is None:
dataset_split_kwargs = {}
if dataset_balance_kwargs is not None:
dataset.balance(**dataset_balance_kwargs)
if val_frac <= 0:
dataset.post_transform = train_transform
train_loader_kwargs: LoaderKwargs = {
"batch_size": min(batch_size, len(dataset)),
"num_workers": 0,
"shuffle": True,
}
if dataloader_kwargs is not None:
train_loader_kwargs: LoaderKwargs = {
**train_loader_kwargs,
**dataloader_kwargs
}
return (
DataLoader(dataset, **train_loader_kwargs),
DataLoader(Dataset())
)
train_dataset, val_dataset = dataset.split(
[1 - val_frac, val_frac],
**dataset_split_kwargs,
)
# Dataset.split() returns light Subset objects of shallow copies of the
# underlying dataset; can change the transform attribute of both splits
# w/o overwriting
train_dataset.post_transform = train_transform
val_dataset.post_transform = val_transform
train_loader_kwargs: LoaderKwargs = {
"batch_size": min(batch_size, len(train_dataset)),
"num_workers": 0,
"shuffle": True,
}
val_loader_kwargs: LoaderKwargs = {
"batch_size": min(batch_size, len(val_dataset)),
"num_workers": 0,
"shuffle": True, # shuffle to prevent homogeneous val batches
}
if dataloader_kwargs is not None:
train_loader_kwargs = {**train_loader_kwargs, **dataloader_kwargs}
val_loader_kwargs = {**val_loader_kwargs, **dataloader_kwargs}
train_loader = DataLoader(train_dataset, **train_loader_kwargs)
val_loader = DataLoader(val_dataset, **val_loader_kwargs)
return train_loader, val_loader
def _summarize(self) -> None:
"""
Flush the training summary to the TensorBoard summary writer.
.. note:: Possibly undesirable behavior
Currently, this method aggregates metrics for the epoch summary
across all logged items *in between summarize calls*. For instance,
if I'm logging every 10 epochs, the stats at epoch=10 are actually
averages from epochs 1-10.
Flush the training summary to the TensorBoard summary writer and print
metrics for the current epoch.
"""
epoch_values = defaultdict(lambda: defaultdict(list))
for group, records in self._summary.items():
for name, steps in records.items():
for step, value in steps.items():
self._writer.add_scalar(f"{group}-{name}", value, step)
if step == self._epoch:
epoch_values[group][name].append(value)
print(f"==== Epoch [{self._epoch}] summary ====")
for group, records in epoch_values.items():
for name, values in records.items():
mean_value = torch.tensor(values).mean().item()
print(
f"> ({len(values)}) [{group}] {name} :: {mean_value:.2f}"
)
for (group, name), epoch_map in self._summary.items():
for epoch, values in epoch_map.items():
mean = torch.tensor(values).mean().item()
self._writer.add_scalar(f"{group}-{name}", mean, epoch)
if epoch == self._epoch:
print(
f"> ({len(values)}) [{group}] {name} :: {mean:.2f}"
)
self._writer.flush()
self._summary = defaultdict(lambda: defaultdict(dict))
self._summary = defaultdict(lambda: defaultdict(list))
def _get_optimizer_parameters(
self,
@@ -622,33 +491,10 @@ class Trainer[I, K: EstimatorKwargs]:
]
def _log_event(self, group: str, name: str, value: float) -> None:
self._summary[group][name][self._epoch] = value
self._event_log[self._session_name][group][name][self._epoch] = value
session, epoch = self._session_name, self._epoch
def get_batch_outputs[B](
self,
batch: B,
batch_estimator_map: Callable[[B, Self], K],
) -> Tensor:
self.estimator.eval()
est_kwargs = batch_estimator_map(batch, self)
output = self.estimator(**est_kwargs)[0]
output = output.detach().cpu()
return output
def get_batch_metrics[B](
self,
batch: B,
batch_estimator_map: Callable[[B, Self], K],
) -> dict[str, float]:
self.estimator.eval()
est_kwargs = batch_estimator_map(batch, self)
metrics = self.estimator.metrics(**est_kwargs)
return metrics
self._summary[group, name][epoch].append(value)
self._event_log[session][group][name][epoch].append(value)
def save_model(self) -> None:
"""

10
trainlib/utils/map.py Normal file
View File

@@ -0,0 +1,10 @@
from collections import defaultdict
def nested_defaultdict(
    depth: int,
    final: type = dict,
) -> defaultdict:
    """
    Build a ``defaultdict`` that auto-creates nesting ``depth`` levels deep.

    The first ``depth - 1`` levels default to further ``defaultdict``s; the
    innermost level defaults to ``final`` (e.g. ``list`` to accumulate
    values at the leaves).

    Parameters:
        depth: number of nesting levels; must be at least 1
        final: factory for values at the deepest level

    Raises:
        ValueError: if ``depth`` is less than 1

    Returns:
        The outermost ``defaultdict``.
    """
    # Guard against depth <= 0, which would otherwise lazily build an
    # unboundedly nesting structure on access.
    if depth < 1:
        raise ValueError(f"depth must be >= 1, got {depth}")
    if depth == 1:
        return defaultdict(final)
    return defaultdict(lambda: nested_defaultdict(depth - 1, final))

View File

@@ -11,7 +11,7 @@ from trainlib.dataset import BatchedDataset
type AxesArray = np.ndarray[tuple[int, int], np.dtype[np.object_]]
class LoaderKwargs(TypedDict, total=False):
batch_size: int
batch_size: int | None
shuffle: bool
sampler: Sampler | Iterable | None
batch_sampler: Sampler[list] | Iterable[list] | None