refactor Trainer (large), improve dataset/est middleware
This commit is contained in:
102
trainlib/dataloader.py
Normal file
102
trainlib/dataloader.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""
|
||||
This class took me a long time to really settle into. It's a connector, and it
|
||||
feels redundant in many ways, so I've nearly deleted it several times while
|
||||
talking through the design. But in total, I think it serves a clear purpose.
|
||||
Reasons:
|
||||
|
||||
- Need a typed dataloader, even if I know the type of my attached transform
|
||||
- Need a new scope that uses the same base dataset without interfering with the
|
||||
transform attribute; a design that sets or relies on that is subject to
|
||||
conflict
|
||||
|
||||
- Why not just use vanilla DataLoaders?
|
||||
|
||||
I'd like to, but the two reasons above make it clear why this is challenging:
|
||||
I don't get static checks on the structures returned during iteration, and
|
||||
while you can control ad hoc data transforms via dataset ``post_transforms``,
|
||||
things can get messy if you need to do that for many transforms using the
|
||||
same dataset (without copying). Simplest way around this is just a new scope
|
||||
with the same underlying dataset instance and a transform wrapper around the
|
||||
iterator; no interference with object attributes.
|
||||
|
||||
This is really just meant as the minimum viable logic needed to accomplish the
|
||||
above - it's a very lightweight wrapper on the base ``DataLoader`` object.
|
||||
There's an explicit type upper bound ``Kw: EstimatorKwargs``, but it is
|
||||
otherwise a completely general transform over dataloader batches, highlighting
|
||||
that it's *mostly* here to play nice with type checks.
|
||||
"""
|
||||
|
||||
from typing import Unpack
|
||||
from collections.abc import Iterator
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from trainlib.dataset import BatchedDataset
|
||||
from trainlib.estimator import EstimatorKwargs
|
||||
from trainlib.utils.type import LoaderKwargs
|
||||
|
||||
|
||||
class EstimatorDataLoader[B, Kw: EstimatorKwargs]:
|
||||
"""
|
||||
Data loaders for estimators.
|
||||
|
||||
This class exists to connect batched data from datasets to the expected
|
||||
representation for estimator methods. Datasets may be developed
|
||||
independently from a given model structures, and models should be trainable
|
||||
under any such data. We need a way to ensure the batched groups of items we
|
||||
get from dataloaders match on a type level, i.e., can be reshaped into the
|
||||
expected ``Kw`` signature.
|
||||
|
||||
Note: batch structure ``B`` cannot be directly inferred from type variables
|
||||
exposed by ``BatchedDatasets`` (namely ``R`` and ``I``). What's returned by
|
||||
a data loader wrapping any such dataset can be arbitrary (depending on the
|
||||
``collate_fn``), with default behavior being fairly consistent under nested
|
||||
collections but challenging to accurately type.
|
||||
|
||||
.. todo::
|
||||
|
||||
To log (have changed for Trainer):
|
||||
|
||||
- New compact eval pipeline for train/val/auxiliary dataloaders.
|
||||
Somewhat annoying logic, but handled consistently
|
||||
- Convergence tracker will dynamically use training loss (early
|
||||
stopping) when a validation set isn't provided. Same mechanics for
|
||||
stagnant epochs (although early stopping is generally a little more
|
||||
nuanced, having a rate-based stopper, b/c train loss generally quite
|
||||
monotonic). So that's to be updated, plus room for possible model
|
||||
selection strategies later.
|
||||
- Logging happens at each batch, but we append to an epoch-indexed list
|
||||
and later average. There was a bug in the last round of testing that
|
||||
I didn't pick up where I was just overwriting summaries using the
|
||||
last seen batch.
|
||||
- Reworked general dataset/dataloader handling for main train loop, now
|
||||
accepting objects of this class to bridge estimator and dataset
|
||||
communication. This cleans up the batch mapping model.
|
||||
- TODO: implement a version of this that canonically works with the
|
||||
device passing plus EstimatorKwargs input; this is the last fuzzy bit
|
||||
I think.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset: BatchedDataset,
|
||||
**dataloader_kwargs: Unpack[LoaderKwargs],
|
||||
) -> None:
|
||||
self._dataloader = DataLoader(dataset, **dataloader_kwargs)
|
||||
|
||||
def batch_to_est_kwargs(self, batch_data: B) -> Kw:
|
||||
"""
|
||||
.. note::
|
||||
|
||||
Even if we have a concrete shape for the output kwarg dict for base
|
||||
estimators (requiring a tensor "inputs" attribute), we don't
|
||||
presuppose how a given batch object will map into this dict
|
||||
structure.
|
||||
|
||||
return EstimatorKwargs({"inputs":0})
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def __iter__(self) -> Iterator[Kw]:
|
||||
return map(self.batch_to_est_kwargs, self._dataloader)
|
||||
@@ -1,21 +1,21 @@
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from collections.abc import Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from trainlib.trainer import Trainer
|
||||
from trainlib.estimator import EstimatorKwargs
|
||||
from trainlib.dataloader import EstimatorDataLoader
|
||||
from trainlib.utils.type import AxesArray, SubplotsKwargs
|
||||
|
||||
type SubplotFn = Callable[[plt.Axes, int, Tensor, Tensor], None]
|
||||
type ContextFn = Callable[[plt.Axes, str], None]
|
||||
|
||||
|
||||
class Plotter[B, K: EstimatorKwargs]:
|
||||
class Plotter[Kw: EstimatorKwargs]:
|
||||
"""
|
||||
TODOs:
|
||||
|
||||
@@ -30,10 +30,9 @@ class Plotter[B, K: EstimatorKwargs]:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
trainer: Trainer[..., K],
|
||||
dataloaders: list[DataLoader],
|
||||
batch_estimator_map: Callable[[B, Trainer], ...],
|
||||
estimator_to_output_map: Callable[[K], ...],
|
||||
trainer: Trainer[Any, Kw],
|
||||
dataloaders: list[EstimatorDataLoader[Any, Kw]],
|
||||
kw_to_actual: Callable[[Kw], Tensor],
|
||||
dataloader_labels: list[str] | None = None,
|
||||
) -> None:
|
||||
self.trainer = trainer
|
||||
@@ -41,47 +40,21 @@ class Plotter[B, K: EstimatorKwargs]:
|
||||
self.dataloader_labels = (
|
||||
dataloader_labels or list(map(str, range(1, len(dataloaders)+1)))
|
||||
)
|
||||
self.batch_estimator_map = batch_estimator_map
|
||||
self.estimator_to_output_map = estimator_to_output_map
|
||||
self.kw_to_actual = kw_to_actual
|
||||
|
||||
self._outputs: list[list[Tensor]] | None = None
|
||||
self._metrics: list[list[dict[str, float]]] | None = None
|
||||
|
||||
self._batch_outputs_fn = partial(
|
||||
self.trainer.get_batch_outputs,
|
||||
batch_estimator_map=batch_estimator_map
|
||||
)
|
||||
self._batch_metrics_fn = partial(
|
||||
self.trainer.get_batch_metrics,
|
||||
batch_estimator_map=batch_estimator_map
|
||||
)
|
||||
|
||||
self._data_tuples = None
|
||||
|
||||
@property
|
||||
def outputs(self) -> list[list[Tensor]]:
|
||||
if self._outputs is None:
|
||||
self._outputs = [
|
||||
list(map(self._batch_outputs_fn, loader))
|
||||
for loader in self.dataloaders
|
||||
]
|
||||
return self._outputs
|
||||
|
||||
@property
|
||||
def metrics(self) -> list[list[dict[str, float]]]:
|
||||
if self._metrics is None:
|
||||
self._metrics = [
|
||||
list(map(self._batch_metrics_fn, loader))
|
||||
for loader in self.dataloaders
|
||||
]
|
||||
return self._metrics
|
||||
|
||||
@property
|
||||
def data_tuples(self) -> list[tuple[Tensor, Tensor, str]]:
|
||||
"""
|
||||
Produce data items; to be cached. Zip later with axes
|
||||
"""
|
||||
|
||||
self.trainer.estimator.eval()
|
||||
|
||||
if self._data_tuples is not None:
|
||||
return self._data_tuples
|
||||
|
||||
@@ -89,10 +62,14 @@ class Plotter[B, K: EstimatorKwargs]:
|
||||
for i, loader in enumerate(self.dataloaders):
|
||||
label = self.dataloader_labels[i]
|
||||
|
||||
batch = next(iter(loader))
|
||||
est_kwargs = self.batch_estimator_map(batch, self.trainer)
|
||||
actual = self.estimator_to_output_map(est_kwargs)
|
||||
output = self._batch_outputs_fn(batch)
|
||||
actual = torch.cat([
|
||||
self.kw_to_actual(batch_kwargs).detach().cpu()
|
||||
for batch_kwargs in loader
|
||||
])
|
||||
output = torch.cat([
|
||||
self.trainer.estimator(**batch_kwargs)[0].detach().cpu()
|
||||
for batch_kwargs in loader
|
||||
])
|
||||
|
||||
data_tuples.append((actual, output, label))
|
||||
|
||||
@@ -219,6 +196,14 @@ class Plotter[B, K: EstimatorKwargs]:
|
||||
Note: transform samples in dataloader definitions beforehand if you
|
||||
want to change data
|
||||
|
||||
.. todo::
|
||||
|
||||
Merge in logic from general diagnostics, allowing collapse from
|
||||
either dim and transposing.
|
||||
|
||||
Later: multi-trial error bars, or at least the ability to pass that
|
||||
downstream
|
||||
|
||||
Parameters:
|
||||
row_size:
|
||||
col_size:
|
||||
@@ -462,3 +447,72 @@ class Plotter[B, K: EstimatorKwargs]:
|
||||
combine_dims=combine_dims,
|
||||
figure_kwargs=figure_kwargs,
|
||||
)
|
||||
|
||||
def estimator_diagnostic(
    self,
    row_size: int | float = 2,
    col_size: int | float = 4,
    session_name: str | None = None,
    combine_groups: bool = False,
    combine_metrics: bool = False,
    transpose_layout: bool = False,
    figure_kwargs: SubplotsKwargs | None = None,
):
    """
    Plot per-epoch metric curves recorded in the trainer's event log.

    Parameters:
        row_size: height allotted to each subplot row
        col_size: width allotted to each subplot column
        session_name: training session to plot; defaults to the first
            session found in the trainer's event log
        combine_groups: draw all groups on a single row of axes
        combine_metrics: draw all metrics on a single column of axes
        transpose_layout: swap the row/column layout of the subplot grid
        figure_kwargs: extra kwargs forwarded to subplot creation
    """

    # NOTE(review): reaches into the trainer's private event log; consider
    # exposing a public accessor on Trainer.
    session_map = self.trainer._event_log
    session_name = session_name or next(iter(session_map))
    groups = session_map[session_name]
    # Assumes every group logs the same number of metrics; the first
    # group's count sizes the grid.
    num_metrics = len(groups[next(iter(groups))])

    rows = 1 if combine_groups else len(groups)
    cols = 1 if combine_metrics else num_metrics
    if transpose_layout:
        rows, cols = cols, rows

    fig, axes = self._create_subplots(
        rows=rows,
        cols=cols,
        row_size=row_size,
        col_size=col_size,
        figure_kwargs=figure_kwargs,
    )
    if transpose_layout:
        # Transpose back so the indexing below stays (group, metric).
        axes = axes.T

    for i, group_name in enumerate(groups):
        axes_row = axes[0 if combine_groups else i]
        group_metrics = groups[group_name]

        for j, metric_name in enumerate(group_metrics):
            ax = axes_row[0 if combine_metrics else j]

            # Average the logged values at each epoch into (epoch, mean)
            # pairs for plotting.
            metric_dict = group_metrics[metric_name]
            metric_data = np.array([
                (k, np.mean(v)) for k, v in metric_dict.items()
            ])

            # Label/title choice depends on which dimensions share axes.
            if combine_groups and combine_metrics:
                label = f"{group_name}-{metric_name}"
                title_prefix = "all"
            elif combine_groups:
                label = group_name
                title_prefix = metric_name
            else:
                label = metric_name
                title_prefix = group_name

            ax.plot(
                metric_data[:, 0],
                metric_data[:, 1],
                label=label,
            )

            ax.set_title(f"[{title_prefix}] Metrics over epochs")
            ax.set_xlabel("epoch", fontstyle='italic')
            ax.set_ylabel("value", fontstyle='italic')
            ax.legend()
||||
|
||||
@@ -1,5 +1,15 @@
|
||||
"""
|
||||
Core interface for training ``Estimators`` with ``Datasets``
|
||||
|
||||
.. admonition:: Design of preview ``get_dataloaders()``
|
||||
|
||||
|
||||
Note how much this method is doing, and the positivity in letting that be
|
||||
more explicit elsewhere. The assignment of transforms to datasets before
|
||||
wrapping as loaders is chief among these items, alongside the balancing and
|
||||
splitting; I think those are hamfisted here to make it work with the old
|
||||
setup, but I generally think it's not consistent with the name "get dataloaders"
|
||||
(i.e., and also balance and split and set transforms)
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -7,48 +17,41 @@ import time
|
||||
import logging
|
||||
from io import BytesIO
|
||||
from copy import deepcopy
|
||||
from typing import Any, Self
|
||||
from typing import Any
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from torch import cuda, Tensor
|
||||
from torch.optim import Optimizer
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from trainlib.dataset import BatchedDataset
|
||||
from trainlib.estimator import Estimator, EstimatorKwargs
|
||||
from trainlib.transform import Transform
|
||||
from trainlib.utils.type import (
|
||||
SplitKwargs,
|
||||
LoaderKwargs,
|
||||
BalanceKwargs,
|
||||
)
|
||||
from trainlib.utils.map import nested_defaultdict
|
||||
from trainlib.dataloader import EstimatorDataLoader
|
||||
from trainlib.utils.module import ModelWrapper
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Trainer[I, K: EstimatorKwargs]:
|
||||
class Trainer[I, Kw: EstimatorKwargs]:
|
||||
"""
|
||||
Training interface for optimizing parameters of ``Estimators`` with
|
||||
``Datasets``.
|
||||
|
||||
This class is generic to a dataset item type ``I`` and an estimator kwarg
|
||||
type ``K``. These are the two primary components ``Trainer`` objects need
|
||||
type ``Kw``. These are the two primary components ``Trainer`` objects need
|
||||
to coordinate: they ultimately rely on a provided map to ensure data items
|
||||
(type ``I``) from a dataset are appropriately routed as inputs to key
|
||||
estimator methods (like ``forward()`` and ``loss()``), which accept inputs
|
||||
of type ``K``.
|
||||
of type ``Kw``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
estimator: Estimator[K],
|
||||
estimator: Estimator[Kw],
|
||||
device: str | None = None,
|
||||
chkpt_dir: str = "chkpt/",
|
||||
tblog_dir: str = "tblog/",
|
||||
@@ -93,6 +96,7 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
|
||||
self.estimator = estimator
|
||||
self.estimator.to(self.device)
|
||||
self._event_log = nested_defaultdict(4, list)
|
||||
|
||||
self.chkpt_dir = Path(chkpt_dir).resolve()
|
||||
self.tblog_dir = Path(tblog_dir).resolve()
|
||||
@@ -105,56 +109,28 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
"""
|
||||
|
||||
self._epoch: int = 1
|
||||
self._summary = defaultdict(lambda: defaultdict(dict))
|
||||
self._event_log = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
|
||||
self._summary = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
self._val_loss = float("inf")
|
||||
self._best_val_loss = float("inf")
|
||||
self._conv_loss = float("inf")
|
||||
self._best_conv_loss = float("inf")
|
||||
self._stagnant_epochs = 0
|
||||
self._best_model_state_dict: dict[str, Any] = {}
|
||||
|
||||
def _train_epoch(
|
||||
self,
|
||||
train_loader: DataLoader,
|
||||
batch_estimator_map: Callable[[I, Self], K],
|
||||
loader: EstimatorDataLoader[Any, Kw],
|
||||
optimizers: tuple[Optimizer, ...],
|
||||
max_grad_norm: float | None = None,
|
||||
) -> list[float]:
|
||||
"""
|
||||
Train the estimator for a single epoch.
|
||||
|
||||
.. admonition:: On summary writers
|
||||
|
||||
Estimators can have several optimizers, and therefore can emit
|
||||
several losses. This is a fairly unique case, but it's needed when
|
||||
we want to optimize particular parameters in a particular order
|
||||
(as in multi-model architectures, e.g., GANs). Point being: we
|
||||
always iterate over optimizers/losses, even in the common case
|
||||
where there's just a single value, and we index collections across
|
||||
batches accordingly.
|
||||
|
||||
A few of the trackers, with the same size as the number of
|
||||
optimizers:
|
||||
|
||||
- ``train_loss_sums``: tracks loss sums across all batches for the
|
||||
epoch, used to update the loop preview text after each batch with
|
||||
the current average loss
|
||||
- ``train_loss_items``: collects current batch losses, recorded by
|
||||
the TB writer
|
||||
|
||||
If there are ``M`` optimizers/losses, we log ``M`` loss terms to
|
||||
the TB writer after each *batch* (not epoch). We could aggregate at
|
||||
an epoch level, but parameter updates take place after each batch,
|
||||
so large model changes can occur over the course of an epoch
|
||||
(whereas the model remains the same over the course batch evals).
|
||||
"""
|
||||
|
||||
loss_sums = []
|
||||
self.estimator.train()
|
||||
with tqdm(train_loader, unit="batch") as batches:
|
||||
for i, batch_data in enumerate(batches):
|
||||
est_kwargs = batch_estimator_map(batch_data, self)
|
||||
losses = self.estimator.loss(**est_kwargs)
|
||||
with tqdm(loader, unit="batch") as batches:
|
||||
for i, batch_kwargs in enumerate(batches):
|
||||
losses = self.estimator.loss(**batch_kwargs)
|
||||
|
||||
for o_idx, (loss, optimizer) in enumerate(
|
||||
zip(losses, optimizers, strict=True)
|
||||
@@ -186,9 +162,8 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
|
||||
def _eval_epoch(
|
||||
self,
|
||||
loader: DataLoader,
|
||||
batch_estimator_map: Callable[[I, Self], K],
|
||||
loader_label: str,
|
||||
loader: EstimatorDataLoader[Any, Kw],
|
||||
label: str
|
||||
) -> list[float]:
|
||||
"""
|
||||
Perform and record validation scores for a single epoch.
|
||||
@@ -217,22 +192,21 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
loss_sums = []
|
||||
self.estimator.eval()
|
||||
with tqdm(loader, unit="batch") as batches:
|
||||
for i, batch_data in enumerate(batches):
|
||||
est_kwargs = batch_estimator_map(batch_data, self)
|
||||
losses = self.estimator.loss(**est_kwargs)
|
||||
for i, batch_kwargs in enumerate(batches):
|
||||
losses = self.estimator.loss(**batch_kwargs)
|
||||
|
||||
# one-time logging
|
||||
if self._epoch == 0:
|
||||
self._writer.add_graph(
|
||||
ModelWrapper(self.estimator), est_kwargs
|
||||
ModelWrapper(self.estimator), batch_kwargs
|
||||
)
|
||||
# once-per-epoch logging
|
||||
if i == 0:
|
||||
self.estimator.epoch_write(
|
||||
self._writer,
|
||||
step=self._epoch,
|
||||
group=loader_label,
|
||||
**est_kwargs
|
||||
group=label,
|
||||
**batch_kwargs
|
||||
)
|
||||
|
||||
loss_items = []
|
||||
@@ -250,21 +224,21 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
|
||||
# log individual loss terms after each batch
|
||||
for o_idx, loss_item in enumerate(loss_items):
|
||||
self._log_event(loader_label, f"loss_{o_idx}", loss_item)
|
||||
self._log_event(label, f"loss_{o_idx}", loss_item)
|
||||
|
||||
# log metrics for batch
|
||||
estimator_metrics = self.estimator.metrics(**est_kwargs)
|
||||
estimator_metrics = self.estimator.metrics(**batch_kwargs)
|
||||
for metric_name, metric_value in estimator_metrics.items():
|
||||
self._log_event(loader_label, metric_name, metric_value)
|
||||
self._log_event(label, metric_name, metric_value)
|
||||
|
||||
return loss_sums
|
||||
|
||||
def _eval_loaders(
|
||||
self,
|
||||
loaders: list[DataLoader],
|
||||
batch_estimator_map: Callable[[I, Self], K],
|
||||
loader_labels: list[str],
|
||||
) -> dict[str, list[float]]:
|
||||
train_loader: EstimatorDataLoader[Any, Kw],
|
||||
val_loader: EstimatorDataLoader[Any, Kw] | None = None,
|
||||
aux_loaders: list[EstimatorDataLoader[Any, Kw]] | None = None,
|
||||
) -> tuple[list[float], list[float] | None, *list[float]]:
|
||||
"""
|
||||
Evaluate estimator over each provided dataloader.
|
||||
|
||||
@@ -280,40 +254,50 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
information (just aggregated losses are provided here).
|
||||
"""
|
||||
|
||||
return {
|
||||
label: self._eval_epoch(loader, batch_estimator_map, label)
|
||||
for loader, label in zip(loaders, loader_labels, strict=True)
|
||||
}
|
||||
train_loss = self._eval_epoch(train_loader, "train")
|
||||
val_loss = self._eval_epoch(val_loader, "val") if val_loader else None
|
||||
|
||||
def train[B](
|
||||
aux_loaders = aux_loaders or []
|
||||
aux_losses = [
|
||||
self._eval_epoch(aux_loader, f"aux{i}")
|
||||
for i, aux_loader in enumerate(aux_loaders)
|
||||
]
|
||||
|
||||
return train_loss, val_loss, *aux_losses
|
||||
|
||||
def train(
|
||||
self,
|
||||
dataset: BatchedDataset[..., ..., I],
|
||||
batch_estimator_map: Callable[[B, Self], K],
|
||||
train_loader: EstimatorDataLoader[Any, Kw],
|
||||
val_loader: EstimatorDataLoader[Any, Kw] | None = None,
|
||||
aux_loaders: list[EstimatorDataLoader[Any, Kw]] | None = None,
|
||||
*,
|
||||
lr: float = 1e-3,
|
||||
eps: float = 1e-8,
|
||||
max_grad_norm: float | None = None,
|
||||
max_epochs: int = 10,
|
||||
stop_after_epochs: int = 5,
|
||||
batch_size: int = 256,
|
||||
val_frac: float = 0.1,
|
||||
train_transform: Transform | None = None,
|
||||
val_transform: Transform | None = None,
|
||||
dataset_split_kwargs: SplitKwargs | None = None,
|
||||
dataset_balance_kwargs: BalanceKwargs | None = None,
|
||||
dataloader_kwargs: LoaderKwargs | None = None,
|
||||
summarize_every: int = 1,
|
||||
chkpt_every: int = 1,
|
||||
resume_latest: bool = False,
|
||||
session_name: str | None = None,
|
||||
summary_writer: SummaryWriter | None = None,
|
||||
aux_loaders: list[DataLoader] | None = None,
|
||||
aux_loader_labels: list[str] | None = None,
|
||||
) -> Estimator:
|
||||
"""
|
||||
TODO: consider making the dataloader ``collate_fn`` an explicit
|
||||
parameter with a type signature that reflects ``B``, connecting the
|
||||
``batch_estimator_map`` somewhere. Might also re-type a ``DataLoader``
|
||||
in-house to allow a generic around ``B``
|
||||
.. todo::
|
||||
|
||||
- consider making the dataloader ``collate_fn`` an explicit
|
||||
parameter with a type signature that reflects ``B``, connecting
|
||||
the ``batch_estimator_map`` somewhere. Might also re-type a
|
||||
``DataLoader`` in-house to allow a generic around ``B``
|
||||
- Rework the validation specification. Accept something like a
|
||||
"validate_with" parameter, or perhaps just move entirely to
|
||||
accepting a dataloader list, label list. You might then also need
|
||||
a "train_with," and you could set up sensible defaults so you
|
||||
basically have the same interaction as now. The "problem" is you
|
||||
always need a train set, and there's some clearly dependent logic
|
||||
on a val set, but you don't *need* val, so this should be
|
||||
slightly reworked (and the more general, *probably* the better in
|
||||
this case, given I want to plug into the Plotter with possibly
|
||||
several purely eval sets over the model training lifetime).
|
||||
|
||||
Note: this method attempts to implement a general scheme for passing
|
||||
needed items to the estimator's loss function from the dataloader. The
|
||||
@@ -355,7 +339,7 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
This function should map from batches - which *may* be item
|
||||
shaped, i.e., have an ``I`` skeleton, even if stacked items may be
|
||||
different on the inside - into estimator keyword arguments (type
|
||||
``K``). Collation behavior from a DataLoader (which can be
|
||||
``Kw``). Collation behavior from a DataLoader (which can be
|
||||
customized) doesn't consistently yield a known type shape, however,
|
||||
so it's not appropriate to use ``I`` as the callable param type.
|
||||
|
||||
@@ -416,18 +400,13 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
dataset
|
||||
val_split_frac: fraction of dataset to use for validation
|
||||
chkpt_every: how often model checkpoints should be saved
|
||||
resume_latest: resume training from the latest available checkpoint
|
||||
in the `chkpt_dir`
|
||||
"""
|
||||
|
||||
logger.info("> Begin train loop:")
|
||||
logger.info(f"| > {lr=}")
|
||||
logger.info(f"| > {eps=}")
|
||||
logger.info(f"| > {max_epochs=}")
|
||||
logger.info(f"| > {batch_size=}")
|
||||
logger.info(f"| > {val_frac=}")
|
||||
logger.info(f"| > {chkpt_every=}")
|
||||
logger.info(f"| > {resume_latest=}")
|
||||
logger.info(f"| > with device: {self.device}")
|
||||
logger.info(f"| > core count: {os.cpu_count()}")
|
||||
|
||||
@@ -435,25 +414,9 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
tblog_path = Path(self.tblog_dir, self._session_name)
|
||||
self._writer = summary_writer or SummaryWriter(f"{tblog_path}")
|
||||
|
||||
aux_loaders = aux_loaders or []
|
||||
aux_loader_labels = aux_loader_labels or []
|
||||
|
||||
optimizers = self.estimator.optimizers(lr=lr, eps=eps)
|
||||
train_loader, val_loader = self.get_dataloaders(
|
||||
dataset,
|
||||
batch_size,
|
||||
val_frac=val_frac,
|
||||
train_transform=train_transform,
|
||||
val_transform=val_transform,
|
||||
dataset_split_kwargs=dataset_split_kwargs,
|
||||
dataset_balance_kwargs=dataset_balance_kwargs,
|
||||
dataloader_kwargs=dataloader_kwargs,
|
||||
)
|
||||
loaders = [train_loader, val_loader, *aux_loaders]
|
||||
loader_labels = ["train", "val", *aux_loader_labels]
|
||||
|
||||
# evaluate model on dataloaders once before training starts
|
||||
self._eval_loaders(loaders, batch_estimator_map, loader_labels)
|
||||
self._eval_loaders(train_loader, val_loader, aux_loaders)
|
||||
optimizers = self.estimator.optimizers(lr=lr, eps=eps)
|
||||
|
||||
while self._epoch <= max_epochs and not self._converged(
|
||||
self._epoch, stop_after_epochs
|
||||
@@ -464,22 +427,14 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
print(f"Stagnant epochs {stag_frac}...")
|
||||
|
||||
epoch_start_time = time.time()
|
||||
self._train_epoch(
|
||||
train_loader,
|
||||
batch_estimator_map,
|
||||
optimizers,
|
||||
max_grad_norm,
|
||||
)
|
||||
self._train_epoch(train_loader, optimizers, max_grad_norm)
|
||||
epoch_end_time = time.time() - epoch_start_time
|
||||
self._log_event("train", "epoch_duration", epoch_end_time)
|
||||
|
||||
loss_sum_map = self._eval_loaders(
|
||||
loaders,
|
||||
batch_estimator_map,
|
||||
loader_labels,
|
||||
train_loss, val_loss, _ = self._eval_loaders(
|
||||
train_loader, val_loader, aux_loaders
|
||||
)
|
||||
val_loss_sums = loss_sum_map["val"]
|
||||
self._val_loss = sum(val_loss_sums) / len(val_loader)
|
||||
self._conv_loss = sum(val_loss) if val_loss else sum(train_loss)
|
||||
|
||||
if self._epoch % summarize_every == 0:
|
||||
self._summarize()
|
||||
@@ -492,8 +447,8 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
|
||||
converged = False
|
||||
|
||||
if epoch == 1 or self._val_loss < self._best_val_loss:
|
||||
self._best_val_loss = self._val_loss
|
||||
if epoch == 1 or self._conv_loss < self._best_val_loss:
|
||||
self._best_val_loss = self._conv_loss
|
||||
self._stagnant_epochs = 0
|
||||
self._best_model_state_dict = deepcopy(self.estimator.state_dict())
|
||||
else:
|
||||
@@ -505,110 +460,24 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
|
||||
return converged
|
||||
|
||||
@staticmethod
|
||||
def get_dataloaders(
|
||||
dataset: BatchedDataset,
|
||||
batch_size: int,
|
||||
val_frac: float = 0.1,
|
||||
train_transform: Transform | None = None,
|
||||
val_transform: Transform | None = None,
|
||||
dataset_split_kwargs: SplitKwargs | None = None,
|
||||
dataset_balance_kwargs: BalanceKwargs | None = None,
|
||||
dataloader_kwargs: LoaderKwargs | None = None,
|
||||
) -> tuple[DataLoader, DataLoader]:
|
||||
"""
|
||||
Create training and validation dataloaders for the provided dataset.
|
||||
|
||||
.. todo::
|
||||
|
||||
Decide on policy for empty val dataloaders
|
||||
"""
|
||||
|
||||
if dataset_split_kwargs is None:
|
||||
dataset_split_kwargs = {}
|
||||
|
||||
if dataset_balance_kwargs is not None:
|
||||
dataset.balance(**dataset_balance_kwargs)
|
||||
|
||||
if val_frac <= 0:
|
||||
dataset.post_transform = train_transform
|
||||
train_loader_kwargs: LoaderKwargs = {
|
||||
"batch_size": min(batch_size, len(dataset)),
|
||||
"num_workers": 0,
|
||||
"shuffle": True,
|
||||
}
|
||||
if dataloader_kwargs is not None:
|
||||
train_loader_kwargs: LoaderKwargs = {
|
||||
**train_loader_kwargs,
|
||||
**dataloader_kwargs
|
||||
}
|
||||
|
||||
return (
|
||||
DataLoader(dataset, **train_loader_kwargs),
|
||||
DataLoader(Dataset())
|
||||
)
|
||||
|
||||
train_dataset, val_dataset = dataset.split(
|
||||
[1 - val_frac, val_frac],
|
||||
**dataset_split_kwargs,
|
||||
)
|
||||
|
||||
# Dataset.split() returns light Subset objects of shallow copies of the
|
||||
# underlying dataset; can change the transform attribute of both splits
|
||||
# w/o overwriting
|
||||
train_dataset.post_transform = train_transform
|
||||
val_dataset.post_transform = val_transform
|
||||
|
||||
train_loader_kwargs: LoaderKwargs = {
|
||||
"batch_size": min(batch_size, len(train_dataset)),
|
||||
"num_workers": 0,
|
||||
"shuffle": True,
|
||||
}
|
||||
val_loader_kwargs: LoaderKwargs = {
|
||||
"batch_size": min(batch_size, len(val_dataset)),
|
||||
"num_workers": 0,
|
||||
"shuffle": True, # shuffle to prevent homogeneous val batches
|
||||
}
|
||||
|
||||
if dataloader_kwargs is not None:
|
||||
train_loader_kwargs = {**train_loader_kwargs, **dataloader_kwargs}
|
||||
val_loader_kwargs = {**val_loader_kwargs, **dataloader_kwargs}
|
||||
|
||||
train_loader = DataLoader(train_dataset, **train_loader_kwargs)
|
||||
val_loader = DataLoader(val_dataset, **val_loader_kwargs)
|
||||
|
||||
return train_loader, val_loader
|
||||
|
||||
def _summarize(self) -> None:
|
||||
"""
|
||||
Flush the training summary to the TensorBoard summary writer.
|
||||
|
||||
.. note:: Possibly undesirable behavior
|
||||
|
||||
Currently, this method aggregates metrics for the epoch summary
|
||||
across all logged items *in between summarize calls*. For instance,
|
||||
if I'm logging every 10 epochs, the stats at epoch=10 are actually
|
||||
averages from epochs 1-10.
|
||||
Flush the training summary to the TensorBoard summary writer and print
|
||||
metrics for the current epoch.
|
||||
"""
|
||||
|
||||
epoch_values = defaultdict(lambda: defaultdict(list))
|
||||
for group, records in self._summary.items():
|
||||
for name, steps in records.items():
|
||||
for step, value in steps.items():
|
||||
self._writer.add_scalar(f"{group}-{name}", value, step)
|
||||
if step == self._epoch:
|
||||
epoch_values[group][name].append(value)
|
||||
|
||||
print(f"==== Epoch [{self._epoch}] summary ====")
|
||||
for group, records in epoch_values.items():
|
||||
for name, values in records.items():
|
||||
mean_value = torch.tensor(values).mean().item()
|
||||
print(
|
||||
f"> ({len(values)}) [{group}] {name} :: {mean_value:.2f}"
|
||||
)
|
||||
for (group, name), epoch_map in self._summary.items():
|
||||
for epoch, values in epoch_map.items():
|
||||
mean = torch.tensor(values).mean().item()
|
||||
self._writer.add_scalar(f"{group}-{name}", mean, epoch)
|
||||
if epoch == self._epoch:
|
||||
print(
|
||||
f"> ({len(values)}) [{group}] {name} :: {mean:.2f}"
|
||||
)
|
||||
|
||||
self._writer.flush()
|
||||
self._summary = defaultdict(lambda: defaultdict(dict))
|
||||
self._summary = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
def _get_optimizer_parameters(
|
||||
self,
|
||||
@@ -622,33 +491,10 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
]
|
||||
|
||||
def _log_event(self, group: str, name: str, value: float) -> None:
|
||||
self._summary[group][name][self._epoch] = value
|
||||
self._event_log[self._session_name][group][name][self._epoch] = value
|
||||
session, epoch = self._session_name, self._epoch
|
||||
|
||||
def get_batch_outputs[B](
|
||||
self,
|
||||
batch: B,
|
||||
batch_estimator_map: Callable[[B, Self], K],
|
||||
) -> Tensor:
|
||||
self.estimator.eval()
|
||||
|
||||
est_kwargs = batch_estimator_map(batch, self)
|
||||
output = self.estimator(**est_kwargs)[0]
|
||||
output = output.detach().cpu()
|
||||
|
||||
return output
|
||||
|
||||
def get_batch_metrics[B](
|
||||
self,
|
||||
batch: B,
|
||||
batch_estimator_map: Callable[[B, Self], K],
|
||||
) -> dict[str, float]:
|
||||
self.estimator.eval()
|
||||
|
||||
est_kwargs = batch_estimator_map(batch, self)
|
||||
metrics = self.estimator.metrics(**est_kwargs)
|
||||
|
||||
return metrics
|
||||
self._summary[group, name][epoch].append(value)
|
||||
self._event_log[session][group][name][epoch].append(value)
|
||||
|
||||
def save_model(self) -> None:
|
||||
"""
|
||||
|
||||
10
trainlib/utils/map.py
Normal file
10
trainlib/utils/map.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def nested_defaultdict(
    depth: int,
    final: type = dict,
) -> defaultdict:
    """
    Build a ``defaultdict`` nested ``depth`` levels deep.

    Missing keys at levels ``1..depth-1`` lazily create another nested
    defaultdict; at the final level they create an instance of ``final``.

    Parameters:
        depth: number of nested levels; must be at least 1
        final: factory for leaf values (e.g. ``list`` or ``dict``)

    Raises:
        ValueError: if ``depth`` is less than 1.
    """

    if depth < 1:
        # Guard: depth <= 0 would otherwise keep nesting without bound as
        # missing keys are accessed, never reaching the base case.
        raise ValueError(f"depth must be >= 1, got {depth}")
    if depth == 1:
        return defaultdict(final)
    # Each missing key at this level lazily constructs the next level down.
    return defaultdict(lambda: nested_defaultdict(depth - 1, final))
|
||||
@@ -11,7 +11,7 @@ from trainlib.dataset import BatchedDataset
|
||||
type AxesArray = np.ndarray[tuple[int, int], np.dtype[np.object_]]
|
||||
|
||||
class LoaderKwargs(TypedDict, total=False):
|
||||
batch_size: int
|
||||
batch_size: int | None
|
||||
shuffle: bool
|
||||
sampler: Sampler | Iterable | None
|
||||
batch_sampler: Sampler[list] | Iterable[list] | None
|
||||
|
||||
Reference in New Issue
Block a user