694 lines
27 KiB
Python
694 lines
27 KiB
Python
"""
|
|
Core interface for training ``Estimators`` with ``Datasets``
|
|
"""
|
|
|
|
import os
|
|
import time
|
|
import logging
|
|
from io import BytesIO
|
|
from copy import deepcopy
|
|
from typing import Any, Self
|
|
from pathlib import Path
|
|
from collections import defaultdict
|
|
from collections.abc import Callable
|
|
|
|
import torch
|
|
from tqdm import tqdm
|
|
from torch import cuda, Tensor
|
|
from torch.optim import Optimizer
|
|
from torch.nn.utils import clip_grad_norm_
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
from trainlib.dataset import BatchedDataset
|
|
from trainlib.estimator import Estimator, EstimatorKwargs
|
|
from trainlib.transform import Transform
|
|
from trainlib.utils.type import (
|
|
SplitKwargs,
|
|
LoaderKwargs,
|
|
BalanceKwargs,
|
|
)
|
|
from trainlib.utils.module import ModelWrapper
|
|
|
|
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Trainer[I, K: EstimatorKwargs]:
|
|
"""
|
|
Training interface for optimizing parameters of ``Estimators`` with
|
|
``Datasets``.
|
|
|
|
This class is generic to a dataset item type ``I`` and an estimator kwarg
|
|
type ``K``. These are the two primary components ``Trainer`` objects need
|
|
to coordinate: they ultimately rely on a provided map to ensure data items
|
|
(type ``I``) from a dataset are appropriately routed as inputs to key
|
|
estimator methods (like ``forward()`` and ``loss()``), which accept inputs
|
|
of type ``K``.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
estimator: Estimator[K],
|
|
device: str | None = None,
|
|
chkpt_dir: str = "chkpt/",
|
|
tblog_dir: str = "tblog/",
|
|
) -> None:
|
|
"""
|
|
Parameters:
|
|
estimator: ``Estimator`` model object
|
|
device: device on which to carry out training
|
|
chkpt_dir: directory to write model checkpoints
|
|
tblog_dir: directory to write TensorBoard logs
|
|
"""
|
|
|
|
self.device: str
|
|
if device is None:
|
|
self.device = "cuda" if cuda.is_available() else "cpu"
|
|
else:
|
|
self.device = device
|
|
|
|
logger.info(f"> Trainer device: {self.device}")
|
|
if self.device.startswith("cuda"):
|
|
if torch.cuda.is_available():
|
|
# extra cuda details
|
|
logger.info(f"| > {cuda.device_count()=}")
|
|
logger.info(f"| > {cuda.current_device()=}")
|
|
logger.info(f"| > {cuda.get_device_name()=}")
|
|
logger.info(f"| > {cuda.get_device_capability()=}")
|
|
|
|
# memory info (in GB)
|
|
gb = 1024**3
|
|
memory_allocated = cuda.memory_allocated() / gb
|
|
memory_reserved = cuda.memory_reserved() / gb
|
|
memory_total = cuda.get_device_properties(0).total_memory / gb
|
|
|
|
logger.info("| > CUDA memory:")
|
|
logger.info(f"| > {memory_total=:.2f}GB")
|
|
logger.info(f"| > {memory_reserved=:.2f}GB")
|
|
logger.info(f"| > {memory_allocated=:.2f}GB")
|
|
else:
|
|
logger.warning("| > CUDA device specified but not available")
|
|
else:
|
|
logger.info("| > Using CPU device - no additional device info")
|
|
|
|
self.estimator = estimator
|
|
self.estimator.to(self.device)
|
|
|
|
self.chkpt_dir = Path(chkpt_dir).resolve()
|
|
self.tblog_dir = Path(tblog_dir).resolve()
|
|
|
|
self.reset()
|
|
|
|
    def reset(self) -> None:
        """
        Set initial tracking parameters for the primary training loop.
        """

        # epoch counter is 1-based; incremented at the bottom of the
        # ``train()`` loop
        self._epoch: int = 1
        # pending scalars awaiting a TB flush: {group: {name: {step: value}}}
        self._summary = defaultdict(lambda: defaultdict(dict))
        # persistent history, never cleared by ``_summarize()``:
        # {session: {group: {name: {step: value}}}}
        self._event_log = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))

        # validation-loss state driving early stopping / best-model restore
        self._val_loss = float("inf")
        self._best_val_loss = float("inf")
        self._stagnant_epochs = 0
        self._best_model_state_dict: dict[str, Any] = {}
|
|
|
|
    def _train_epoch(
        self,
        train_loader: DataLoader,
        batch_estimator_map: Callable[[I, Self], K],
        optimizers: tuple[Optimizer, ...],
        max_grad_norm: float | None = None,
    ) -> list[float]:
        """
        Train the estimator for a single epoch.

        Parameters:
            train_loader: dataloader yielding training batches
            batch_estimator_map: function mapping from batch data to expected
                estimator kwargs
            optimizers: optimizers to step, aligned 1:1 with the losses
                returned by ``self.estimator.loss()``
            max_grad_norm: upper bound to use when clipping gradients. If left
                as ``None``, no gradient clipping is performed.

        Returns:
            Per-optimizer loss sums accumulated over all batches.

        .. admonition:: On summary writers

            Estimators can have several optimizers, and therefore can emit
            several losses. This is a fairly unique case, but it's needed when
            we want to optimize particular parameters in a particular order
            (as in multi-model architectures, e.g., GANs). Point being: we
            always iterate over optimizers/losses, even in the common case
            where there's just a single value, and we index collections across
            batches accordingly.

            A few of the trackers, with the same size as the number of
            optimizers:

            - ``train_loss_sums``: tracks loss sums across all batches for the
              epoch, used to update the loop preview text after each batch with
              the current average loss
            - ``train_loss_items``: collects current batch losses, recorded by
              the TB writer

            If there are ``M`` optimizers/losses, we log ``M`` loss terms to
            the TB writer after each *batch* (not epoch). We could aggregate at
            an epoch level, but parameter updates take place after each batch,
            so large model changes can occur over the course of an epoch
            (whereas the model remains the same over the course batch evals).
        """

        loss_sums: list[float] = []
        self.estimator.train()
        with tqdm(train_loader, unit="batch") as batches:
            for i, batch_data in enumerate(batches):
                est_kwargs = batch_estimator_map(batch_data, self)
                losses = self.estimator.loss(**est_kwargs)

                # one backward/step per (loss, optimizer) pair, in order;
                # ``strict=True`` enforces the 1:1 alignment
                for o_idx, (loss, optimizer) in enumerate(
                    zip(losses, optimizers, strict=True)
                ):
                    # lazily grow the sums list on the first batch
                    if len(loss_sums) <= o_idx:
                        loss_sums.append(0.0)
                    loss_sums[o_idx] += loss.item()

                    optimizer.zero_grad()
                    loss.backward()

                    # clip gradients for optimizer's parameters
                    if max_grad_norm is not None:
                        clip_grad_norm_(
                            self._get_optimizer_parameters(optimizer),
                            max_norm=max_grad_norm
                        )

                    optimizer.step()

                # set loop loss to running average (reducing if multi-loss)
                loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
                batches.set_postfix(loss=f"{loss_avg:8.2f}")

        # step estimator hyperparam schedules
        self.estimator.epoch_step()

        return loss_sums
|
|
|
|
def _eval_epoch(
|
|
self,
|
|
loader: DataLoader,
|
|
batch_estimator_map: Callable[[I, Self], K],
|
|
loader_label: str,
|
|
) -> list[float]:
|
|
"""
|
|
Perform and record validation scores for a single epoch.
|
|
|
|
.. On summary writers::
|
|
|
|
See the similarly titled note for ``_train_epoch()`` for general
|
|
remarks about optimizers and how we're recording losses/metrics.
|
|
The same mostly applies here in the validation setting, but we
|
|
crucially aren't stepping forward ``_step`` between batches. This
|
|
means that, even though we're writing losses and metrics once for
|
|
each val batch, those values are simply piling up under the same
|
|
summary item name. This is consistent with how we report training
|
|
items: the model isn't changing across val batches, so these don't
|
|
get plotted at different step points (which might be interpreted as
|
|
performance along model progression when it's actually variation
|
|
across batches). Currently, this "piling" means we actually write
|
|
batch values to the same event name at the same step rather than
|
|
manually averaging beforehand; we defer the handling to TB, and
|
|
although that may technically be discouraged, the val plots render
|
|
collections effectively as a vertical line between the min and max
|
|
value (which I find to be a satisfactory way to view cross-batch
|
|
variation during each val epoch).
|
|
"""
|
|
|
|
loss_sums = []
|
|
self.estimator.eval()
|
|
with tqdm(loader, unit="batch") as batches:
|
|
for i, batch_data in enumerate(batches):
|
|
est_kwargs = batch_estimator_map(batch_data, self)
|
|
losses = self.estimator.loss(**est_kwargs)
|
|
|
|
# one-time logging
|
|
if self._epoch == 0:
|
|
self._writer.add_graph(
|
|
ModelWrapper(self.estimator), est_kwargs
|
|
)
|
|
# once-per-epoch logging
|
|
if i == 0:
|
|
self.estimator.epoch_write(
|
|
self._writer,
|
|
step=self._epoch,
|
|
group=loader_label,
|
|
**est_kwargs
|
|
)
|
|
|
|
loss_items = []
|
|
for o_idx, loss in enumerate(losses):
|
|
if len(loss_sums) <= o_idx:
|
|
loss_sums.append(0.0)
|
|
|
|
loss_item = loss.item()
|
|
loss_sums[o_idx] += loss_item
|
|
loss_items.append(loss_item)
|
|
|
|
# set loop loss to running average (reducing if multi-loss)
|
|
loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
|
|
batches.set_postfix(loss=f"{loss_avg:8.2f}")
|
|
|
|
# log individual loss terms after each batch
|
|
for o_idx, loss_item in enumerate(loss_items):
|
|
self._log_event(loader_label, f"loss_{o_idx}", loss_item)
|
|
|
|
# log metrics for batch
|
|
estimator_metrics = self.estimator.metrics(**est_kwargs)
|
|
for metric_name, metric_value in estimator_metrics.items():
|
|
self._log_event(loader_label, metric_name, metric_value)
|
|
|
|
return loss_sums
|
|
|
|
def _eval_loaders(
|
|
self,
|
|
loaders: list[DataLoader],
|
|
batch_estimator_map: Callable[[I, Self], K],
|
|
loader_labels: list[str],
|
|
) -> dict[str, list[float]]:
|
|
"""
|
|
Evaluate estimator over each provided dataloader.
|
|
|
|
This streamlines the general train/val/etc evaluation pipeline during
|
|
training. This triggers logging of summary items, TensorBoard writes,
|
|
etc, and is therefore exclusively intended for use inside the primary
|
|
``train()`` loop.
|
|
|
|
If looking to use this method publicly, you should instead work with
|
|
individual dataloaders and call helper methods like
|
|
``get_batch_outputs()`` or ``get_batch_metrics()`` while iterating over
|
|
batches. This will have no internal side effects and provides much more
|
|
information (just aggregated losses are provided here).
|
|
"""
|
|
|
|
return {
|
|
label: self._eval_epoch(loader, batch_estimator_map, label)
|
|
for loader, label in zip(loaders, loader_labels, strict=True)
|
|
}
|
|
|
|
def train[B](
|
|
self,
|
|
dataset: BatchedDataset[..., ..., I],
|
|
batch_estimator_map: Callable[[B, Self], K],
|
|
lr: float = 1e-3,
|
|
eps: float = 1e-8,
|
|
max_grad_norm: float | None = None,
|
|
max_epochs: int = 10,
|
|
stop_after_epochs: int = 5,
|
|
batch_size: int = 256,
|
|
val_frac: float = 0.1,
|
|
train_transform: Transform | None = None,
|
|
val_transform: Transform | None = None,
|
|
dataset_split_kwargs: SplitKwargs | None = None,
|
|
dataset_balance_kwargs: BalanceKwargs | None = None,
|
|
dataloader_kwargs: LoaderKwargs | None = None,
|
|
summarize_every: int = 1,
|
|
chkpt_every: int = 1,
|
|
resume_latest: bool = False,
|
|
session_name: str | None = None,
|
|
summary_writer: SummaryWriter | None = None,
|
|
aux_loaders: list[DataLoader] | None = None,
|
|
aux_loader_labels: list[str] | None = None,
|
|
) -> Estimator:
|
|
"""
|
|
TODO: consider making the dataloader ``collate_fn`` an explicit
|
|
parameter with a type signature that reflects ``B``, connecting the
|
|
``batch_estimator_map`` somewhere. Might also re-type a ``DataLoader``
|
|
in-house to allow a generic around ``B``
|
|
|
|
Note: this method attempts to implement a general scheme for passing
|
|
needed items to the estimator's loss function from the dataloader. The
|
|
abstract ``Estimator`` base only requires the model output be provided
|
|
for any given loss calculation, but concrete estimators will often
|
|
require additional arguments (e.g., labels or length masks, as is the
|
|
case with sequential models). In any case, this method defers any
|
|
further logic to the ``loss`` method of the underlying estimator, so
|
|
one should take care to synchronize the sample structure with `dataset`
|
|
to match that expected by ``self.estimator.loss(...)``.
|
|
|
|
.. admonition:: On ``batch_estimator_map``
|
|
|
|
Dataloader collate functions are responsible for mapping a
|
|
collection of items into an item of collections, roughly speaking.
|
|
If items are tuples of tensors,
|
|
|
|
.. code-block:: text
|
|
|
|
[
|
|
( [1, 1], [1, 1] ),
|
|
( [2, 2], [2, 2] ),
|
|
( [3, 3], [3, 3] ),
|
|
]
|
|
|
|
the collate function maps back into the item skeleton, producing a
|
|
single tuple of (stacked) tensors
|
|
|
|
.. code-block:: text
|
|
|
|
( [[1, 1],
|
|
[2, 2],
|
|
[3, 3]],
|
|
|
|
[[1, 1],
|
|
[2, 2],
|
|
[3, 3]] )
|
|
|
|
This function should map from batches - which *may* be item
|
|
shaped, i.e., have an ``I`` skeleton, even if stacked items may be
|
|
different on the inside - into estimator keyword arguments (type
|
|
``K``). Collation behavior from a DataLoader (which can be
|
|
customized) doesn't consistently yield a known type shape, however,
|
|
so it's not appropriate to use ``I`` as the callable param type.
|
|
|
|
.. admonition:: On session management
|
|
|
|
This method works around an implicit notion of training sessions.
|
|
Estimators are set during instantiation and effectively coupled
|
|
with ``Trainer`` instances, but datasets can be supplied
|
|
dynamically here. One can, for instance, run under one condition
|
|
(specific dataset, number of epochs, etc), then resume later under
|
|
another. Relevant details persist across calls: the estimator is
|
|
still attached, best val scores stowed, current epoch tracked. By
|
|
default, new ``session_names`` are always generated, but you can
|
|
write to the same TB location if you using the same
|
|
``session_name`` across calls; that's about as close to a direct
|
|
training resume as you could want.
|
|
|
|
If restarting training on new datasets, including short
|
|
fine-tuning on training-plus-validation data, it's often sensible
|
|
to call ``.reset()`` between ``.train()`` calls. While the same
|
|
estimator will be used, tracked variables will be wiped; subsequent
|
|
model updates take place under a fresh epoch, no val losses, and be
|
|
logged under a separate TB session. This is the general approach to
|
|
"piecemeal" training, i.e., incremental model updates under varying
|
|
conditions (often just data changes).
|
|
|
|
.. warning::
|
|
|
|
Validation convergence when there are multiple losses may be
|
|
ambiguous. These are cases where certain parameter sets are
|
|
optimized independently; the sum over these losses may not reflect
|
|
expected or consistent behavior. For instance, we may record a low
|
|
cumulative loss early with a small 1st loss and moderate 2nd loss,
|
|
while later encountering a moderate 1st lost and small 2nd loss. We
|
|
might prefer the latter case, while ``_converged()`` will stick to
|
|
the former -- we need to consider possible weighting across losses,
|
|
or storing possibly several best models (e.g., for each loss, the
|
|
model that scores best, plus the one scoring best cumulatively,
|
|
etc).
|
|
|
|
Parameters:
|
|
dataset: dataset to train the estimator
|
|
batch_estimator_map: function mapping from batch data to expected
|
|
estimator kwargs
|
|
lr: learning rate (default: 1e-3)
|
|
eps: adam EPS (default: 1e-8)
|
|
max_grad_norm: upper bound to use when clipping gradients. If left
|
|
as ``None``, no gradient clipping is performed.
|
|
max_epochs: maximum number of training epochs
|
|
stop_after_epochs: number of epochs with stagnant validation losses
|
|
to allow before early stopping. If training stops earlier, the
|
|
parameters for the best recorded validation score are loaded
|
|
into the estimator before the method returns. If
|
|
`stop_after_epochs >= max_epochs`, the estimator will train
|
|
over all epochs and return as is, irrespective of validation
|
|
scores.
|
|
batch_size: size of batch to use when training on the provided
|
|
dataset
|
|
val_split_frac: fraction of dataset to use for validation
|
|
chkpt_every: how often model checkpoints should be saved
|
|
resume_latest: resume training from the latest available checkpoint
|
|
in the `chkpt_dir`
|
|
"""
|
|
|
|
logger.info("> Begin train loop:")
|
|
logger.info(f"| > {lr=}")
|
|
logger.info(f"| > {eps=}")
|
|
logger.info(f"| > {max_epochs=}")
|
|
logger.info(f"| > {batch_size=}")
|
|
logger.info(f"| > {val_frac=}")
|
|
logger.info(f"| > {chkpt_every=}")
|
|
logger.info(f"| > {resume_latest=}")
|
|
logger.info(f"| > with device: {self.device}")
|
|
logger.info(f"| > core count: {os.cpu_count()}")
|
|
|
|
self._session_name = session_name or str(int(time.time()))
|
|
tblog_path = Path(self.tblog_dir, self._session_name)
|
|
self._writer = summary_writer or SummaryWriter(f"{tblog_path}")
|
|
|
|
aux_loaders = aux_loaders or []
|
|
aux_loader_labels = aux_loader_labels or []
|
|
|
|
optimizers = self.estimator.optimizers(lr=lr, eps=eps)
|
|
train_loader, val_loader = self.get_dataloaders(
|
|
dataset,
|
|
batch_size,
|
|
val_frac=val_frac,
|
|
train_transform=train_transform,
|
|
val_transform=val_transform,
|
|
dataset_split_kwargs=dataset_split_kwargs,
|
|
dataset_balance_kwargs=dataset_balance_kwargs,
|
|
dataloader_kwargs=dataloader_kwargs,
|
|
)
|
|
loaders = [train_loader, val_loader, *aux_loaders]
|
|
loader_labels = ["train", "val", *aux_loader_labels]
|
|
|
|
# evaluate model on dataloaders once before training starts
|
|
self._eval_loaders(loaders, batch_estimator_map, loader_labels)
|
|
|
|
while self._epoch <= max_epochs and not self._converged(
|
|
self._epoch, stop_after_epochs
|
|
):
|
|
train_frac = f"{self._epoch}/{max_epochs}"
|
|
stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}"
|
|
print(f"Training epoch {train_frac}...")
|
|
print(f"Stagnant epochs {stag_frac}...")
|
|
|
|
epoch_start_time = time.time()
|
|
self._train_epoch(
|
|
train_loader,
|
|
batch_estimator_map,
|
|
optimizers,
|
|
max_grad_norm,
|
|
)
|
|
epoch_end_time = time.time() - epoch_start_time
|
|
self._log_event("train", "epoch_duration", epoch_end_time)
|
|
|
|
loss_sum_map = self._eval_loaders(
|
|
loaders,
|
|
batch_estimator_map,
|
|
loader_labels,
|
|
)
|
|
val_loss_sums = loss_sum_map["val"]
|
|
self._val_loss = sum(val_loss_sums) / len(val_loader)
|
|
|
|
if self._epoch % summarize_every == 0:
|
|
self._summarize()
|
|
if self._epoch % chkpt_every == 0:
|
|
self.save_model()
|
|
self._epoch += 1
|
|
|
|
return self.estimator
|
|
|
|
def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
|
|
converged = False
|
|
|
|
if epoch == 1 or self._val_loss < self._best_val_loss:
|
|
self._best_val_loss = self._val_loss
|
|
self._stagnant_epochs = 0
|
|
self._best_model_state_dict = deepcopy(self.estimator.state_dict())
|
|
else:
|
|
self._stagnant_epochs += 1
|
|
|
|
if self._stagnant_epochs >= stop_after_epochs:
|
|
self.estimator.load_state_dict(self._best_model_state_dict)
|
|
converged = True
|
|
|
|
return converged
|
|
|
|
@staticmethod
|
|
def get_dataloaders(
|
|
dataset: BatchedDataset,
|
|
batch_size: int,
|
|
val_frac: float = 0.1,
|
|
train_transform: Transform | None = None,
|
|
val_transform: Transform | None = None,
|
|
dataset_split_kwargs: SplitKwargs | None = None,
|
|
dataset_balance_kwargs: BalanceKwargs | None = None,
|
|
dataloader_kwargs: LoaderKwargs | None = None,
|
|
) -> tuple[DataLoader, DataLoader]:
|
|
"""
|
|
Create training and validation dataloaders for the provided dataset.
|
|
|
|
.. todo::
|
|
|
|
Decide on policy for empty val dataloaders
|
|
"""
|
|
|
|
if dataset_split_kwargs is None:
|
|
dataset_split_kwargs = {}
|
|
|
|
if dataset_balance_kwargs is not None:
|
|
dataset.balance(**dataset_balance_kwargs)
|
|
|
|
if val_frac <= 0:
|
|
dataset.post_transform = train_transform
|
|
train_loader_kwargs: LoaderKwargs = {
|
|
"batch_size": min(batch_size, len(dataset)),
|
|
"num_workers": 0,
|
|
"shuffle": True,
|
|
}
|
|
if dataloader_kwargs is not None:
|
|
train_loader_kwargs: LoaderKwargs = {
|
|
**train_loader_kwargs,
|
|
**dataloader_kwargs
|
|
}
|
|
|
|
return (
|
|
DataLoader(dataset, **train_loader_kwargs),
|
|
DataLoader(Dataset())
|
|
)
|
|
|
|
train_dataset, val_dataset = dataset.split(
|
|
[1 - val_frac, val_frac],
|
|
**dataset_split_kwargs,
|
|
)
|
|
|
|
# Dataset.split() returns light Subset objects of shallow copies of the
|
|
# underlying dataset; can change the transform attribute of both splits
|
|
# w/o overwriting
|
|
train_dataset.post_transform = train_transform
|
|
val_dataset.post_transform = val_transform
|
|
|
|
train_loader_kwargs: LoaderKwargs = {
|
|
"batch_size": min(batch_size, len(train_dataset)),
|
|
"num_workers": 0,
|
|
"shuffle": True,
|
|
}
|
|
val_loader_kwargs: LoaderKwargs = {
|
|
"batch_size": min(batch_size, len(val_dataset)),
|
|
"num_workers": 0,
|
|
"shuffle": True, # shuffle to prevent homogeneous val batches
|
|
}
|
|
|
|
if dataloader_kwargs is not None:
|
|
train_loader_kwargs = {**train_loader_kwargs, **dataloader_kwargs}
|
|
val_loader_kwargs = {**val_loader_kwargs, **dataloader_kwargs}
|
|
|
|
train_loader = DataLoader(train_dataset, **train_loader_kwargs)
|
|
val_loader = DataLoader(val_dataset, **val_loader_kwargs)
|
|
|
|
return train_loader, val_loader
|
|
|
|
def _summarize(self) -> None:
|
|
"""
|
|
Flush the training summary to the TensorBoard summary writer.
|
|
|
|
.. note:: Possibly undesirable behavior
|
|
|
|
Currently, this method aggregates metrics for the epoch summary
|
|
across all logged items *in between summarize calls*. For instance,
|
|
if I'm logging every 10 epochs, the stats at epoch=10 are actually
|
|
averages from epochs 1-10.
|
|
"""
|
|
|
|
epoch_values = defaultdict(lambda: defaultdict(list))
|
|
for group, records in self._summary.items():
|
|
for name, steps in records.items():
|
|
for step, value in steps.items():
|
|
self._writer.add_scalar(f"{group}-{name}", value, step)
|
|
if step == self._epoch:
|
|
epoch_values[group][name].append(value)
|
|
|
|
print(f"==== Epoch [{self._epoch}] summary ====")
|
|
for group, records in epoch_values.items():
|
|
for name, values in records.items():
|
|
mean_value = torch.tensor(values).mean().item()
|
|
print(
|
|
f"> ({len(values)}) [{group}] {name} :: {mean_value:.2f}"
|
|
)
|
|
|
|
self._writer.flush()
|
|
self._summary = defaultdict(lambda: defaultdict(dict))
|
|
|
|
def _get_optimizer_parameters(
|
|
self,
|
|
optimizer: torch.optim.Optimizer,
|
|
) -> list[Tensor]:
|
|
return [
|
|
param
|
|
for param_group in optimizer.param_groups
|
|
for param in param_group["params"]
|
|
if param.grad is not None
|
|
]
|
|
|
|
def _log_event(self, group: str, name: str, value: float) -> None:
|
|
self._summary[group][name][self._epoch] = value
|
|
self._event_log[self._session_name][group][name][self._epoch] = value
|
|
|
|
def get_batch_outputs[B](
|
|
self,
|
|
batch: B,
|
|
batch_estimator_map: Callable[[B, Self], K],
|
|
) -> Tensor:
|
|
self.estimator.eval()
|
|
|
|
est_kwargs = batch_estimator_map(batch, self)
|
|
output = self.estimator(**est_kwargs)[0]
|
|
output = output.detach().cpu()
|
|
|
|
return output
|
|
|
|
def get_batch_metrics[B](
|
|
self,
|
|
batch: B,
|
|
batch_estimator_map: Callable[[B, Self], K],
|
|
) -> dict[str, float]:
|
|
self.estimator.eval()
|
|
|
|
est_kwargs = batch_estimator_map(batch, self)
|
|
metrics = self.estimator.metrics(**est_kwargs)
|
|
|
|
return metrics
|
|
|
|
def save_model(self) -> None:
|
|
"""
|
|
Save a model checkpoint.
|
|
"""
|
|
|
|
model_buff = BytesIO()
|
|
torch.save(self.estimator.state_dict(), model_buff)
|
|
model_buff.seek(0)
|
|
|
|
model_class = self.estimator.__class__.__name__
|
|
chkpt_name = f"m_{model_class}-e_{self._epoch}.pth"
|
|
|
|
chkpt_dir = Path(self.chkpt_dir, self._session_name)
|
|
chkpt_path = Path(chkpt_dir, chkpt_name)
|
|
|
|
chkpt_dir.mkdir(parents=True, exist_ok=True)
|
|
chkpt_path.write_bytes(model_buff.getvalue())
|
|
|
|
def load_model(self, chkpt_dir: str, epoch: int) -> None:
|
|
"""
|
|
Load a model checkpoint from a given epoch.
|
|
|
|
Note that this assumes the model was saved via
|
|
``Trainer.save_model()``, and the estimator provided to this
|
|
``Trainer`` instance matches the architecture of the checkpoint model
|
|
being loaded.
|
|
|
|
Parameters:
|
|
epoch: epoch of saved model
|
|
chkpt_dir:
|
|
"""
|
|
|
|
model_class = self.estimator.__class__.__name__
|
|
chkpt_name = f"m_{model_class}-e_{epoch}.pth"
|
|
chkpt_path = Path(chkpt_dir, chkpt_name)
|
|
|
|
model_buff = BytesIO(chkpt_path.read_bytes())
|
|
model_buff.seek(0)
|
|
|
|
model_dict = torch.load(model_buff, weights_only=True)
|
|
self.estimator.load_state_dict(model_dict)
|