# trainlib/trainer.py
# (source listing header: 694 lines, 27 KiB, Python)
"""
Core interface for training ``Estimators`` with ``Datasets``
"""
import os
import time
import logging
from io import BytesIO
from copy import deepcopy
from typing import Any, Self
from pathlib import Path
from collections import defaultdict
from collections.abc import Callable
import torch
from tqdm import tqdm
from torch import cuda, Tensor
from torch.optim import Optimizer
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from trainlib.dataset import BatchedDataset
from trainlib.estimator import Estimator, EstimatorKwargs
from trainlib.transform import Transform
from trainlib.utils.type import (
SplitKwargs,
LoaderKwargs,
BalanceKwargs,
)
from trainlib.utils.module import ModelWrapper
logger: logging.Logger = logging.getLogger(__name__)
class Trainer[I, K: EstimatorKwargs]:
    """
    Training interface for optimizing parameters of ``Estimators`` with
    ``Datasets``.

    This class is generic to a dataset item type ``I`` and an estimator kwarg
    type ``K``. These are the two primary components ``Trainer`` objects need
    to coordinate: they ultimately rely on a provided map to ensure data items
    (type ``I``) from a dataset are appropriately routed as inputs to key
    estimator methods (like ``forward()`` and ``loss()``), which accept inputs
    of type ``K``.
    """
def __init__(
self,
estimator: Estimator[K],
device: str | None = None,
chkpt_dir: str = "chkpt/",
tblog_dir: str = "tblog/",
) -> None:
"""
Parameters:
estimator: ``Estimator`` model object
device: device on which to carry out training
chkpt_dir: directory to write model checkpoints
tblog_dir: directory to write TensorBoard logs
"""
self.device: str
if device is None:
self.device = "cuda" if cuda.is_available() else "cpu"
else:
self.device = device
logger.info(f"> Trainer device: {self.device}")
if self.device.startswith("cuda"):
if torch.cuda.is_available():
# extra cuda details
logger.info(f"| > {cuda.device_count()=}")
logger.info(f"| > {cuda.current_device()=}")
logger.info(f"| > {cuda.get_device_name()=}")
logger.info(f"| > {cuda.get_device_capability()=}")
# memory info (in GB)
gb = 1024**3
memory_allocated = cuda.memory_allocated() / gb
memory_reserved = cuda.memory_reserved() / gb
memory_total = cuda.get_device_properties(0).total_memory / gb
logger.info("| > CUDA memory:")
logger.info(f"| > {memory_total=:.2f}GB")
logger.info(f"| > {memory_reserved=:.2f}GB")
logger.info(f"| > {memory_allocated=:.2f}GB")
else:
logger.warning("| > CUDA device specified but not available")
else:
logger.info("| > Using CPU device - no additional device info")
self.estimator = estimator
self.estimator.to(self.device)
self.chkpt_dir = Path(chkpt_dir).resolve()
self.tblog_dir = Path(tblog_dir).resolve()
self.reset()
def reset(self) -> None:
"""
Set initial tracking parameters for the primary training loop.
"""
self._epoch: int = 1
self._summary = defaultdict(lambda: defaultdict(dict))
self._event_log = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
self._val_loss = float("inf")
self._best_val_loss = float("inf")
self._stagnant_epochs = 0
self._best_model_state_dict: dict[str, Any] = {}
    def _train_epoch(
        self,
        train_loader: DataLoader,
        batch_estimator_map: Callable[[I, Self], K],
        optimizers: tuple[Optimizer, ...],
        max_grad_norm: float | None = None,
    ) -> list[float]:
        """
        Train the estimator for a single epoch.

        Parameters:
            train_loader: dataloader yielding training batches
            batch_estimator_map: function mapping a batch (and this trainer)
                to the keyword arguments expected by ``estimator.loss()``
            optimizers: one optimizer per loss term emitted by the estimator
            max_grad_norm: if set, clip each optimizer's gradients to this
                max norm before stepping

        Returns:
            Per-optimizer loss sums accumulated over all batches this epoch.

        .. admonition:: On summary writers

            Estimators can have several optimizers, and therefore can emit
            several losses. This is a fairly unique case, but it's needed when
            we want to optimize particular parameters in a particular order
            (as in multi-model architectures, e.g., GANs). Point being: we
            always iterate over optimizers/losses, even in the common case
            where there's just a single value, and we index collections across
            batches accordingly.

            A few of the trackers, with the same size as the number of
            optimizers:

            - ``train_loss_sums``: tracks loss sums across all batches for the
              epoch, used to update the loop preview text after each batch with
              the current average loss
            - ``train_loss_items``: collects current batch losses, recorded by
              the TB writer

            If there are ``M`` optimizers/losses, we log ``M`` loss terms to
            the TB writer after each *batch* (not epoch). We could aggregate at
            an epoch level, but parameter updates take place after each batch,
            so large model changes can occur over the course of an epoch
            (whereas the model remains the same over the course of batch evals).
        """
        loss_sums = []
        self.estimator.train()
        with tqdm(train_loader, unit="batch") as batches:
            for i, batch_data in enumerate(batches):
                est_kwargs = batch_estimator_map(batch_data, self)
                losses = self.estimator.loss(**est_kwargs)
                # losses and optimizers are index-aligned; strict zip enforces
                # that the estimator emits exactly one loss per optimizer
                for o_idx, (loss, optimizer) in enumerate(
                    zip(losses, optimizers, strict=True)
                ):
                    # grow the running-sum list lazily on the first batch
                    if len(loss_sums) <= o_idx:
                        loss_sums.append(0.0)
                    loss_sums[o_idx] += loss.item()
                    # zero/backward/step per optimizer, in order; ordering
                    # matters for multi-model setups (see admonition above)
                    optimizer.zero_grad()
                    loss.backward()
                    # clip gradients for optimizer's parameters
                    if max_grad_norm is not None:
                        clip_grad_norm_(
                            self._get_optimizer_parameters(optimizer),
                            max_norm=max_grad_norm
                        )
                    optimizer.step()
                # set loop loss to running average (reducing if multi-loss)
                loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
                batches.set_postfix(loss=f"{loss_avg:8.2f}")
        # step estimator hyperparam schedules
        self.estimator.epoch_step()
        return loss_sums
def _eval_epoch(
self,
loader: DataLoader,
batch_estimator_map: Callable[[I, Self], K],
loader_label: str,
) -> list[float]:
"""
Perform and record validation scores for a single epoch.
.. On summary writers::
See the similarly titled note for ``_train_epoch()`` for general
remarks about optimizers and how we're recording losses/metrics.
The same mostly applies here in the validation setting, but we
crucially aren't stepping forward ``_step`` between batches. This
means that, even though we're writing losses and metrics once for
each val batch, those values are simply piling up under the same
summary item name. This is consistent with how we report training
items: the model isn't changing across val batches, so these don't
get plotted at different step points (which might be interpreted as
performance along model progression when it's actually variation
across batches). Currently, this "piling" means we actually write
batch values to the same event name at the same step rather than
manually averaging beforehand; we defer the handling to TB, and
although that may technically be discouraged, the val plots render
collections effectively as a vertical line between the min and max
value (which I find to be a satisfactory way to view cross-batch
variation during each val epoch).
"""
loss_sums = []
self.estimator.eval()
with tqdm(loader, unit="batch") as batches:
for i, batch_data in enumerate(batches):
est_kwargs = batch_estimator_map(batch_data, self)
losses = self.estimator.loss(**est_kwargs)
# one-time logging
if self._epoch == 0:
self._writer.add_graph(
ModelWrapper(self.estimator), est_kwargs
)
# once-per-epoch logging
if i == 0:
self.estimator.epoch_write(
self._writer,
step=self._epoch,
group=loader_label,
**est_kwargs
)
loss_items = []
for o_idx, loss in enumerate(losses):
if len(loss_sums) <= o_idx:
loss_sums.append(0.0)
loss_item = loss.item()
loss_sums[o_idx] += loss_item
loss_items.append(loss_item)
# set loop loss to running average (reducing if multi-loss)
loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
batches.set_postfix(loss=f"{loss_avg:8.2f}")
# log individual loss terms after each batch
for o_idx, loss_item in enumerate(loss_items):
self._log_event(loader_label, f"loss_{o_idx}", loss_item)
# log metrics for batch
estimator_metrics = self.estimator.metrics(**est_kwargs)
for metric_name, metric_value in estimator_metrics.items():
self._log_event(loader_label, metric_name, metric_value)
return loss_sums
def _eval_loaders(
self,
loaders: list[DataLoader],
batch_estimator_map: Callable[[I, Self], K],
loader_labels: list[str],
) -> dict[str, list[float]]:
"""
Evaluate estimator over each provided dataloader.
This streamlines the general train/val/etc evaluation pipeline during
training. This triggers logging of summary items, TensorBoard writes,
etc, and is therefore exclusively intended for use inside the primary
``train()`` loop.
If looking to use this method publicly, you should instead work with
individual dataloaders and call helper methods like
``get_batch_outputs()`` or ``get_batch_metrics()`` while iterating over
batches. This will have no internal side effects and provides much more
information (just aggregated losses are provided here).
"""
return {
label: self._eval_epoch(loader, batch_estimator_map, label)
for loader, label in zip(loaders, loader_labels, strict=True)
}
    def train[B](
        self,
        dataset: BatchedDataset[..., ..., I],
        batch_estimator_map: Callable[[B, Self], K],
        lr: float = 1e-3,
        eps: float = 1e-8,
        max_grad_norm: float | None = None,
        max_epochs: int = 10,
        stop_after_epochs: int = 5,
        batch_size: int = 256,
        val_frac: float = 0.1,
        train_transform: Transform | None = None,
        val_transform: Transform | None = None,
        dataset_split_kwargs: SplitKwargs | None = None,
        dataset_balance_kwargs: BalanceKwargs | None = None,
        dataloader_kwargs: LoaderKwargs | None = None,
        summarize_every: int = 1,
        chkpt_every: int = 1,
        resume_latest: bool = False,
        session_name: str | None = None,
        summary_writer: SummaryWriter | None = None,
        aux_loaders: list[DataLoader] | None = None,
        aux_loader_labels: list[str] | None = None,
    ) -> Estimator:
        """
        Run the primary training loop, returning the trained estimator.

        TODO: consider making the dataloader ``collate_fn`` an explicit
        parameter with a type signature that reflects ``B``, connecting the
        ``batch_estimator_map`` somewhere. Might also re-type a ``DataLoader``
        in-house to allow a generic around ``B``

        Note: this method attempts to implement a general scheme for passing
        needed items to the estimator's loss function from the dataloader. The
        abstract ``Estimator`` base only requires the model output be provided
        for any given loss calculation, but concrete estimators will often
        require additional arguments (e.g., labels or length masks, as is the
        case with sequential models). In any case, this method defers any
        further logic to the ``loss`` method of the underlying estimator, so
        one should take care to synchronize the sample structure with
        ``dataset`` to match that expected by ``self.estimator.loss(...)``.

        .. admonition:: On ``batch_estimator_map``

            Dataloader collate functions are responsible for mapping a
            collection of items into an item of collections, roughly speaking.
            If items are tuples of tensors,

            .. code-block:: text

                [
                    ( [1, 1], [1, 1] ),
                    ( [2, 2], [2, 2] ),
                    ( [3, 3], [3, 3] ),
                ]

            the collate function maps back into the item skeleton, producing a
            single tuple of (stacked) tensors

            .. code-block:: text

                ( [[1, 1],
                   [2, 2],
                   [3, 3]],
                  [[1, 1],
                   [2, 2],
                   [3, 3]] )

            This function should map from batches - which *may* be item
            shaped, i.e., have an ``I`` skeleton, even if stacked items may be
            different on the inside - into estimator keyword arguments (type
            ``K``). Collation behavior from a DataLoader (which can be
            customized) doesn't consistently yield a known type shape, however,
            so it's not appropriate to use ``I`` as the callable param type.

        .. admonition:: On session management

            This method works around an implicit notion of training sessions.
            Estimators are set during instantiation and effectively coupled
            with ``Trainer`` instances, but datasets can be supplied
            dynamically here. One can, for instance, run under one condition
            (specific dataset, number of epochs, etc), then resume later under
            another. Relevant details persist across calls: the estimator is
            still attached, best val scores stowed, current epoch tracked. By
            default, new ``session_names`` are always generated, but you can
            write to the same TB location if you use the same ``session_name``
            across calls; that's about as close to a direct training resume as
            you could want.

            If restarting training on new datasets, including short
            fine-tuning on training-plus-validation data, it's often sensible
            to call ``.reset()`` between ``.train()`` calls. While the same
            estimator will be used, tracked variables will be wiped; subsequent
            model updates take place under a fresh epoch, no val losses, and be
            logged under a separate TB session. This is the general approach to
            "piecemeal" training, i.e., incremental model updates under varying
            conditions (often just data changes).

        .. warning::
            Validation convergence when there are multiple losses may be
            ambiguous. These are cases where certain parameter sets are
            optimized independently; the sum over these losses may not reflect
            expected or consistent behavior. For instance, we may record a low
            cumulative loss early with a small 1st loss and moderate 2nd loss,
            while later encountering a moderate 1st loss and small 2nd loss. We
            might prefer the latter case, while ``_converged()`` will stick to
            the former -- we need to consider possible weighting across losses,
            or storing possibly several best models (e.g., for each loss, the
            model that scores best, plus the one scoring best cumulatively,
            etc).

        Parameters:
            dataset: dataset to train the estimator
            batch_estimator_map: function mapping from batch data to expected
                estimator kwargs
            lr: learning rate (default: 1e-3)
            eps: adam EPS (default: 1e-8)
            max_grad_norm: upper bound to use when clipping gradients. If left
                as ``None``, no gradient clipping is performed.
            max_epochs: maximum number of training epochs
            stop_after_epochs: number of epochs with stagnant validation losses
                to allow before early stopping. If training stops earlier, the
                parameters for the best recorded validation score are loaded
                into the estimator before the method returns. If
                `stop_after_epochs >= max_epochs`, the estimator will train
                over all epochs and return as is, irrespective of validation
                scores.
            batch_size: size of batch to use when training on the provided
                dataset
            val_frac: fraction of dataset to use for validation
            train_transform: post-transform applied to the training split
            val_transform: post-transform applied to the validation split
            dataset_split_kwargs: extra kwargs forwarded to ``dataset.split()``
            dataset_balance_kwargs: kwargs forwarded to ``dataset.balance()``
            dataloader_kwargs: overrides merged into both loaders' kwargs
            summarize_every: epoch interval at which buffered summary events
                are flushed via ``_summarize()``
            chkpt_every: how often model checkpoints should be saved
            resume_latest: resume training from the latest available checkpoint
                in the `chkpt_dir`. NOTE(review): this flag is currently only
                logged; no checkpoint is loaded anywhere in this method --
                confirm intended behavior.
            session_name: label for this training session (defaults to a unix
                timestamp); determines the TB log and checkpoint subdirectory
            summary_writer: externally managed ``SummaryWriter`` to use in
                place of a session-scoped one
            aux_loaders: additional dataloaders evaluated each epoch
            aux_loader_labels: event-group labels paired with ``aux_loaders``
        """
        logger.info("> Begin train loop:")
        logger.info(f"| > {lr=}")
        logger.info(f"| > {eps=}")
        logger.info(f"| > {max_epochs=}")
        logger.info(f"| > {batch_size=}")
        logger.info(f"| > {val_frac=}")
        logger.info(f"| > {chkpt_every=}")
        logger.info(f"| > {resume_latest=}")
        logger.info(f"| > with device: {self.device}")
        logger.info(f"| > core count: {os.cpu_count()}")
        # default session names are unix timestamps
        self._session_name = session_name or str(int(time.time()))
        tblog_path = Path(self.tblog_dir, self._session_name)
        self._writer = summary_writer or SummaryWriter(f"{tblog_path}")
        aux_loaders = aux_loaders or []
        aux_loader_labels = aux_loader_labels or []
        optimizers = self.estimator.optimizers(lr=lr, eps=eps)
        train_loader, val_loader = self.get_dataloaders(
            dataset,
            batch_size,
            val_frac=val_frac,
            train_transform=train_transform,
            val_transform=val_transform,
            dataset_split_kwargs=dataset_split_kwargs,
            dataset_balance_kwargs=dataset_balance_kwargs,
            dataloader_kwargs=dataloader_kwargs,
        )
        loaders = [train_loader, val_loader, *aux_loaders]
        loader_labels = ["train", "val", *aux_loader_labels]
        # evaluate model on dataloaders once before training starts
        self._eval_loaders(loaders, batch_estimator_map, loader_labels)
        while self._epoch <= max_epochs and not self._converged(
            self._epoch, stop_after_epochs
        ):
            train_frac = f"{self._epoch}/{max_epochs}"
            stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}"
            print(f"Training epoch {train_frac}...")
            print(f"Stagnant epochs {stag_frac}...")
            epoch_start_time = time.time()
            self._train_epoch(
                train_loader,
                batch_estimator_map,
                optimizers,
                max_grad_norm,
            )
            epoch_end_time = time.time() - epoch_start_time
            self._log_event("train", "epoch_duration", epoch_end_time)
            # re-evaluate all loaders with the updated parameters
            loss_sum_map = self._eval_loaders(
                loaders,
                batch_estimator_map,
                loader_labels,
            )
            val_loss_sums = loss_sum_map["val"]
            # average summed val losses over the number of val batches
            self._val_loss = sum(val_loss_sums) / len(val_loader)
            if self._epoch % summarize_every == 0:
                self._summarize()
            if self._epoch % chkpt_every == 0:
                self.save_model()
            self._epoch += 1
        return self.estimator
def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
converged = False
if epoch == 1 or self._val_loss < self._best_val_loss:
self._best_val_loss = self._val_loss
self._stagnant_epochs = 0
self._best_model_state_dict = deepcopy(self.estimator.state_dict())
else:
self._stagnant_epochs += 1
if self._stagnant_epochs >= stop_after_epochs:
self.estimator.load_state_dict(self._best_model_state_dict)
converged = True
return converged
@staticmethod
def get_dataloaders(
dataset: BatchedDataset,
batch_size: int,
val_frac: float = 0.1,
train_transform: Transform | None = None,
val_transform: Transform | None = None,
dataset_split_kwargs: SplitKwargs | None = None,
dataset_balance_kwargs: BalanceKwargs | None = None,
dataloader_kwargs: LoaderKwargs | None = None,
) -> tuple[DataLoader, DataLoader]:
"""
Create training and validation dataloaders for the provided dataset.
.. todo::
Decide on policy for empty val dataloaders
"""
if dataset_split_kwargs is None:
dataset_split_kwargs = {}
if dataset_balance_kwargs is not None:
dataset.balance(**dataset_balance_kwargs)
if val_frac <= 0:
dataset.post_transform = train_transform
train_loader_kwargs: LoaderKwargs = {
"batch_size": min(batch_size, len(dataset)),
"num_workers": 0,
"shuffle": True,
}
if dataloader_kwargs is not None:
train_loader_kwargs: LoaderKwargs = {
**train_loader_kwargs,
**dataloader_kwargs
}
return (
DataLoader(dataset, **train_loader_kwargs),
DataLoader(Dataset())
)
train_dataset, val_dataset = dataset.split(
[1 - val_frac, val_frac],
**dataset_split_kwargs,
)
# Dataset.split() returns light Subset objects of shallow copies of the
# underlying dataset; can change the transform attribute of both splits
# w/o overwriting
train_dataset.post_transform = train_transform
val_dataset.post_transform = val_transform
train_loader_kwargs: LoaderKwargs = {
"batch_size": min(batch_size, len(train_dataset)),
"num_workers": 0,
"shuffle": True,
}
val_loader_kwargs: LoaderKwargs = {
"batch_size": min(batch_size, len(val_dataset)),
"num_workers": 0,
"shuffle": True, # shuffle to prevent homogeneous val batches
}
if dataloader_kwargs is not None:
train_loader_kwargs = {**train_loader_kwargs, **dataloader_kwargs}
val_loader_kwargs = {**val_loader_kwargs, **dataloader_kwargs}
train_loader = DataLoader(train_dataset, **train_loader_kwargs)
val_loader = DataLoader(val_dataset, **val_loader_kwargs)
return train_loader, val_loader
def _summarize(self) -> None:
"""
Flush the training summary to the TensorBoard summary writer.
.. note:: Possibly undesirable behavior
Currently, this method aggregates metrics for the epoch summary
across all logged items *in between summarize calls*. For instance,
if I'm logging every 10 epochs, the stats at epoch=10 are actually
averages from epochs 1-10.
"""
epoch_values = defaultdict(lambda: defaultdict(list))
for group, records in self._summary.items():
for name, steps in records.items():
for step, value in steps.items():
self._writer.add_scalar(f"{group}-{name}", value, step)
if step == self._epoch:
epoch_values[group][name].append(value)
print(f"==== Epoch [{self._epoch}] summary ====")
for group, records in epoch_values.items():
for name, values in records.items():
mean_value = torch.tensor(values).mean().item()
print(
f"> ({len(values)}) [{group}] {name} :: {mean_value:.2f}"
)
self._writer.flush()
self._summary = defaultdict(lambda: defaultdict(dict))
def _get_optimizer_parameters(
self,
optimizer: torch.optim.Optimizer,
) -> list[Tensor]:
return [
param
for param_group in optimizer.param_groups
for param in param_group["params"]
if param.grad is not None
]
def _log_event(self, group: str, name: str, value: float) -> None:
self._summary[group][name][self._epoch] = value
self._event_log[self._session_name][group][name][self._epoch] = value
def get_batch_outputs[B](
self,
batch: B,
batch_estimator_map: Callable[[B, Self], K],
) -> Tensor:
self.estimator.eval()
est_kwargs = batch_estimator_map(batch, self)
output = self.estimator(**est_kwargs)[0]
output = output.detach().cpu()
return output
def get_batch_metrics[B](
self,
batch: B,
batch_estimator_map: Callable[[B, Self], K],
) -> dict[str, float]:
self.estimator.eval()
est_kwargs = batch_estimator_map(batch, self)
metrics = self.estimator.metrics(**est_kwargs)
return metrics
def save_model(self) -> None:
"""
Save a model checkpoint.
"""
model_buff = BytesIO()
torch.save(self.estimator.state_dict(), model_buff)
model_buff.seek(0)
model_class = self.estimator.__class__.__name__
chkpt_name = f"m_{model_class}-e_{self._epoch}.pth"
chkpt_dir = Path(self.chkpt_dir, self._session_name)
chkpt_path = Path(chkpt_dir, chkpt_name)
chkpt_dir.mkdir(parents=True, exist_ok=True)
chkpt_path.write_bytes(model_buff.getvalue())
def load_model(self, chkpt_dir: str, epoch: int) -> None:
"""
Load a model checkpoint from a given epoch.
Note that this assumes the model was saved via
``Trainer.save_model()``, and the estimator provided to this
``Trainer`` instance matches the architecture of the checkpoint model
being loaded.
Parameters:
epoch: epoch of saved model
chkpt_dir:
"""
model_class = self.estimator.__class__.__name__
chkpt_name = f"m_{model_class}-e_{epoch}.pth"
chkpt_path = Path(chkpt_dir, chkpt_name)
model_buff = BytesIO(chkpt_path.read_bytes())
model_buff.seek(0)
model_dict = torch.load(model_buff, weights_only=True)
self.estimator.load_state_dict(model_dict)