diff --git a/trainlib/dataloader.py b/trainlib/dataloader.py new file mode 100644 index 0000000..0d8d5ac --- /dev/null +++ b/trainlib/dataloader.py @@ -0,0 +1,102 @@ +""" +This class took me a long time to really settle into. It's a connector, and it +feels redundant in many ways, so I've nearly deleted it several times while +talking through the design. But in total, I think it serves a clear purpose. +Reasons: + +- Need a typed dataloader, even if I know the type of my attached transform +- Need a new scope that uses the same base dataset without interfering with the + transform attribute; a design that sets or relies on that is subject to + conflict + +- Why not just use vanilla DataLoaders? + + I'd like to, but the two reasons above make it clear why this is challenging: + I don't get static checks on the structures returned during iteration, and + while you can control ad hoc data transforms via dataset ``post_transforms``, + things can get messy if you need to do that for many transforms using the + same dataset (without copying). Simplest way around this is just a new scope + with the same underlying dataset instance and a transform wrapper around the + iterator; no interference with object attributes. + +This is really just meant as the minimum viable logic needed to accomplish the +above - it's a very lightweight wrapper on the base ``DataLoader`` object. +There's an explicit type upper bound ``Kw: EstimatorKwargs``, but it is +otherwise a completely general transform over dataloader batches, highlighting +that it's *mostly* here to place nice with type checks. +""" + +from typing import Unpack +from collections.abc import Iterator + +from torch.utils.data import DataLoader + +from trainlib.dataset import BatchedDataset +from trainlib.estimator import EstimatorKwargs +from trainlib.utils.type import LoaderKwargs + + +class EstimatorDataLoader[B, Kw: EstimatorKwargs]: + """ + Data loaders for estimators. 
+ + This class exists to connect batched data from datasets to the expected + representation for estimator methods. Datasets may be developed + independently from a given model structures, and models should be trainable + under any such data. We need a way to ensure the batched groups of items we + get from dataloaders match on a type level, i.e., can be reshaped into the + expected ``Kw`` signature. + + Note: batch structure ``B`` cannot be directly inferred from type variables + exposed by ``BatchedDatasets`` (namely ``R`` and ``I``). What's returned by + a data loader wrapping any such dataset can be arbitrary (depending on the + ``collate_fn``), with default behavior being fairly consistent under nested + collections but challenging to accurately type. + + .. todo:: + + To log (have changed for Trainer): + + - New compact eval pipeline for train/val/auxiliary dataloaders. + Somewhat annoying logic, but handled consistently + - Convergence tracker will dynamically use training loss (early + stopping) when a validation set isn't provided. Same mechanics for + stagnant epochs (although early stopping is generally a little more + nuanced, having a rate-based stopper, b/c train loss generally quite + monotonic). So that's to be updated, plus room for possible model + selection strategies later. + - Logging happens at each batch, but we append to an epoch-indexed list + and later average. There was a bug in the last round of testing that + I didn't pick up where I was just overwriting summaries using the + last seen batch. + - Reworked general dataset/dataloader handling for main train loop, now + accepting objects of this class to bridge estimator and dataset + communication. This cleans up the batch mapping model. + - TODO: implement a version of this that canonically works with the + device passing plus EstimatorKwargs input; this is the last fuzzy bit + I think. 
+ """ + + def __init__( + self, + dataset: BatchedDataset, + **dataloader_kwargs: Unpack[LoaderKwargs], + ) -> None: + self._dataloader = DataLoader(dataset, **dataloader_kwargs) + + def batch_to_est_kwargs(self, batch_data: B) -> Kw: + """ + .. note:: + + Even if we have a concrete shape for the output kwarg dict for base + estimators (requiring a tensor "inputs" attribute), we don't + presuppose how a given batch object will map into this dict + structure. + + return EstimatorKwargs({"inputs":0}) + """ + + raise NotImplementedError + + def __iter__(self) -> Iterator[Kw]: + return map(self.batch_to_est_kwargs, self._dataloader) diff --git a/trainlib/diagnostic.py b/trainlib/diagnostic.py deleted file mode 100644 index e69de29..0000000 diff --git a/trainlib/plotter.py b/trainlib/plotter.py index 1446c13..f696606 100644 --- a/trainlib/plotter.py +++ b/trainlib/plotter.py @@ -1,21 +1,21 @@ -from functools import partial +from typing import Any from collections.abc import Callable import numpy as np import torch import matplotlib.pyplot as plt from torch import Tensor -from torch.utils.data import DataLoader from trainlib.trainer import Trainer from trainlib.estimator import EstimatorKwargs +from trainlib.dataloader import EstimatorDataLoader from trainlib.utils.type import AxesArray, SubplotsKwargs type SubplotFn = Callable[[plt.Axes, int, Tensor, Tensor], None] type ContextFn = Callable[[plt.Axes, str], None] -class Plotter[B, K: EstimatorKwargs]: +class Plotter[Kw: EstimatorKwargs]: """ TODOs: @@ -30,10 +30,9 @@ class Plotter[B, K: EstimatorKwargs]: def __init__( self, - trainer: Trainer[..., K], - dataloaders: list[DataLoader], - batch_estimator_map: Callable[[B, Trainer], ...], - estimator_to_output_map: Callable[[K], ...], + trainer: Trainer[Any, Kw], + dataloaders: list[EstimatorDataLoader[Any, Kw]], + kw_to_actual: Callable[[Kw], Tensor], dataloader_labels: list[str] | None = None, ) -> None: self.trainer = trainer @@ -41,47 +40,21 @@ class Plotter[B, K: 
EstimatorKwargs]: self.dataloader_labels = ( dataloader_labels or list(map(str, range(1, len(dataloaders)+1))) ) - self.batch_estimator_map = batch_estimator_map - self.estimator_to_output_map = estimator_to_output_map + self.kw_to_actual = kw_to_actual self._outputs: list[list[Tensor]] | None = None self._metrics: list[list[dict[str, float]]] | None = None - self._batch_outputs_fn = partial( - self.trainer.get_batch_outputs, - batch_estimator_map=batch_estimator_map - ) - self._batch_metrics_fn = partial( - self.trainer.get_batch_metrics, - batch_estimator_map=batch_estimator_map - ) - self._data_tuples = None - @property - def outputs(self) -> list[list[Tensor]]: - if self._outputs is None: - self._outputs = [ - list(map(self._batch_outputs_fn, loader)) - for loader in self.dataloaders - ] - return self._outputs - - @property - def metrics(self) -> list[list[dict[str, float]]]: - if self._metrics is None: - self._metrics = [ - list(map(self._batch_metrics_fn, loader)) - for loader in self.dataloaders - ] - return self._metrics - @property def data_tuples(self) -> list[tuple[Tensor, Tensor, str]]: """ Produce data items; to be cached. 
Zip later with axes """ + self.trainer.estimator.eval() + if self._data_tuples is not None: return self._data_tuples @@ -89,10 +62,14 @@ class Plotter[B, K: EstimatorKwargs]: for i, loader in enumerate(self.dataloaders): label = self.dataloader_labels[i] - batch = next(iter(loader)) - est_kwargs = self.batch_estimator_map(batch, self.trainer) - actual = self.estimator_to_output_map(est_kwargs) - output = self._batch_outputs_fn(batch) + actual = torch.cat([ + self.kw_to_actual(batch_kwargs).detach().cpu() + for batch_kwargs in loader + ]) + output = torch.cat([ + self.trainer.estimator(**batch_kwargs)[0].detach().cpu() + for batch_kwargs in loader + ]) data_tuples.append((actual, output, label)) @@ -219,6 +196,14 @@ class Plotter[B, K: EstimatorKwargs]: Note: transform samples in dataloader definitions beforehand if you want to change data + .. todo:: + + Merge in logic from general diagnostics, allowing collapse from + either dim and transposing. + + Later: multi-trial error bars, or at least the ability to pass that + downstream + Parameters: row_size: col_size: @@ -462,3 +447,72 @@ class Plotter[B, K: EstimatorKwargs]: combine_dims=combine_dims, figure_kwargs=figure_kwargs, ) + + def estimator_diagnostic( + self, + row_size: int | float = 2, + col_size: int | float = 4, + session_name: str | None = None, + combine_groups: bool = False, + combine_metrics: bool = False, + transpose_layout: bool = False, + figure_kwargs: SubplotsKwargs | None = None, + ): + session_map = self.trainer._event_log + session_name = session_name or next(iter(session_map)) + groups = session_map[session_name] + num_metrics = len(groups[next(iter(groups))]) + colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] + + rows = 1 if combine_groups else len(groups) + cols = 1 if combine_metrics else num_metrics + if transpose_layout: + rows, cols = cols, rows + + fig, axes = self._create_subplots( + rows=rows, + cols=cols, + row_size=row_size, + col_size=col_size, + 
figure_kwargs=figure_kwargs, + ) + if transpose_layout: + axes = axes.T + + for i, group_name in enumerate(groups): + axes_row = axes[0 if combine_groups else i] + group_metrics = groups[group_name] + + for j, metric_name in enumerate(group_metrics): + ax = axes_row[0 if combine_metrics else j] + + metric_dict = group_metrics[metric_name] + metric_data = np.array([ + (k, np.mean(v)) for k, v in metric_dict.items() + ]) + + if combine_groups and combine_metrics: + label = f"{group_name}-{metric_name}" + title_prefix = "all" + elif combine_groups: + label = group_name + title_prefix = metric_name + # elif combine_metrics: + else: + label = metric_name + title_prefix = group_name + # else: + # label = "" + # title_prefix = f"{group_name},{metric_name}" + + ax.plot( + metric_data[:, 0], + metric_data[:, 1], + label=label, + # color=colors[j], + ) + + ax.set_title(f"[{title_prefix}] Metrics over epochs") + ax.set_xlabel("epoch", fontstyle='italic') + ax.set_ylabel("value", fontstyle='italic') + ax.legend() diff --git a/trainlib/trainer.py b/trainlib/trainer.py index 16e6006..0c00eeb 100644 --- a/trainlib/trainer.py +++ b/trainlib/trainer.py @@ -1,5 +1,15 @@ """ Core interface for training ``Estimators`` with ``Datasets`` + +.. admonition:: Design of preview ``get_dataloaders()`` + + + Note how much this method is doing, and the positivity in letting that be + more explicit elsewhere. 
The assignment of transforms to datasets before + wrapping as loaders is chief among these items, alongside the balancing and + splitting; I think those are hamfisted here to make it work with the old + setup, but I generally it's not consistent with the name "get dataloaders" + (i.e., and also balance and split and set transforms) """ import os @@ -7,48 +17,41 @@ import time import logging from io import BytesIO from copy import deepcopy -from typing import Any, Self +from typing import Any from pathlib import Path from collections import defaultdict -from collections.abc import Callable import torch from tqdm import tqdm from torch import cuda, Tensor from torch.optim import Optimizer from torch.nn.utils import clip_grad_norm_ -from torch.utils.data import Dataset, DataLoader from torch.utils.tensorboard import SummaryWriter -from trainlib.dataset import BatchedDataset from trainlib.estimator import Estimator, EstimatorKwargs -from trainlib.transform import Transform -from trainlib.utils.type import ( - SplitKwargs, - LoaderKwargs, - BalanceKwargs, -) +from trainlib.utils.map import nested_defaultdict +from trainlib.dataloader import EstimatorDataLoader from trainlib.utils.module import ModelWrapper logger: logging.Logger = logging.getLogger(__name__) -class Trainer[I, K: EstimatorKwargs]: +class Trainer[I, Kw: EstimatorKwargs]: """ Training interface for optimizing parameters of ``Estimators`` with ``Datasets``. This class is generic to a dataset item type ``I`` and an estimator kwarg - type ``K``. These are the two primary components ``Trainer`` objects need + type ``Kw``. These are the two primary components ``Trainer`` objects need to coordinate: they ultimately rely on a provided map to ensure data items (type ``I``) from a dataset are appropriately routed as inputs to key estimator methods (like ``forward()`` and ``loss()``), which accept inputs - of type ``K``. + of type ``Kw``. 
""" def __init__( self, - estimator: Estimator[K], + estimator: Estimator[Kw], device: str | None = None, chkpt_dir: str = "chkpt/", tblog_dir: str = "tblog/", @@ -93,6 +96,7 @@ class Trainer[I, K: EstimatorKwargs]: self.estimator = estimator self.estimator.to(self.device) + self._event_log = nested_defaultdict(4, list) self.chkpt_dir = Path(chkpt_dir).resolve() self.tblog_dir = Path(tblog_dir).resolve() @@ -105,56 +109,28 @@ class Trainer[I, K: EstimatorKwargs]: """ self._epoch: int = 1 - self._summary = defaultdict(lambda: defaultdict(dict)) - self._event_log = defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) + self._summary = defaultdict(lambda: defaultdict(list)) - self._val_loss = float("inf") - self._best_val_loss = float("inf") + self._conv_loss = float("inf") + self._best_conv_loss = float("inf") self._stagnant_epochs = 0 self._best_model_state_dict: dict[str, Any] = {} def _train_epoch( self, - train_loader: DataLoader, - batch_estimator_map: Callable[[I, Self], K], + loader: EstimatorDataLoader[Any, Kw], optimizers: tuple[Optimizer, ...], max_grad_norm: float | None = None, ) -> list[float]: """ Train the estimator for a single epoch. - - .. admonition:: On summary writers - - Estimators can have several optimizers, and therefore can emit - several losses. This is a fairly unique case, but it's needed when - we want to optimize particular parameters in a particular order - (as in multi-model architectures, e.g., GANs). Point being: we - always iterate over optimizers/losses, even in the common case - where there's just a single value, and we index collections across - batches accordingly. 
- - A few of the trackers, with the same size as the number of - optimizers: - - - ``train_loss_sums``: tracks loss sums across all batches for the - epoch, used to update the loop preview text after each batch with - the current average loss - - ``train_loss_items``: collects current batch losses, recorded by - the TB writer - - If there are ``M`` optimizers/losses, we log ``M`` loss terms to - the TB writer after each *batch* (not epoch). We could aggregate at - an epoch level, but parameter updates take place after each batch, - so large model changes can occur over the course of an epoch - (whereas the model remains the same over the course batch evals). """ loss_sums = [] self.estimator.train() - with tqdm(train_loader, unit="batch") as batches: - for i, batch_data in enumerate(batches): - est_kwargs = batch_estimator_map(batch_data, self) - losses = self.estimator.loss(**est_kwargs) + with tqdm(loader, unit="batch") as batches: + for i, batch_kwargs in enumerate(batches): + losses = self.estimator.loss(**batch_kwargs) for o_idx, (loss, optimizer) in enumerate( zip(losses, optimizers, strict=True) @@ -186,9 +162,8 @@ class Trainer[I, K: EstimatorKwargs]: def _eval_epoch( self, - loader: DataLoader, - batch_estimator_map: Callable[[I, Self], K], - loader_label: str, + loader: EstimatorDataLoader[Any, Kw], + label: str ) -> list[float]: """ Perform and record validation scores for a single epoch. 
@@ -217,22 +192,21 @@ class Trainer[I, K: EstimatorKwargs]: loss_sums = [] self.estimator.eval() with tqdm(loader, unit="batch") as batches: - for i, batch_data in enumerate(batches): - est_kwargs = batch_estimator_map(batch_data, self) - losses = self.estimator.loss(**est_kwargs) + for i, batch_kwargs in enumerate(batches): + losses = self.estimator.loss(**batch_kwargs) # one-time logging if self._epoch == 0: self._writer.add_graph( - ModelWrapper(self.estimator), est_kwargs + ModelWrapper(self.estimator), batch_kwargs ) # once-per-epoch logging if i == 0: self.estimator.epoch_write( self._writer, step=self._epoch, - group=loader_label, - **est_kwargs + group=label, + **batch_kwargs ) loss_items = [] @@ -250,21 +224,21 @@ class Trainer[I, K: EstimatorKwargs]: # log individual loss terms after each batch for o_idx, loss_item in enumerate(loss_items): - self._log_event(loader_label, f"loss_{o_idx}", loss_item) + self._log_event(label, f"loss_{o_idx}", loss_item) # log metrics for batch - estimator_metrics = self.estimator.metrics(**est_kwargs) + estimator_metrics = self.estimator.metrics(**batch_kwargs) for metric_name, metric_value in estimator_metrics.items(): - self._log_event(loader_label, metric_name, metric_value) + self._log_event(label, metric_name, metric_value) return loss_sums def _eval_loaders( self, - loaders: list[DataLoader], - batch_estimator_map: Callable[[I, Self], K], - loader_labels: list[str], - ) -> dict[str, list[float]]: + train_loader: EstimatorDataLoader[Any, Kw], + val_loader: EstimatorDataLoader[Any, Kw] | None = None, + aux_loaders: list[EstimatorDataLoader[Any, Kw]] | None = None, + ) -> tuple[list[float], list[float] | None, *list[float]]: """ Evaluate estimator over each provided dataloader. @@ -280,40 +254,50 @@ class Trainer[I, K: EstimatorKwargs]: information (just aggregated losses are provided here). 
""" - return { - label: self._eval_epoch(loader, batch_estimator_map, label) - for loader, label in zip(loaders, loader_labels, strict=True) - } + train_loss = self._eval_epoch(train_loader, "train") + val_loss = self._eval_epoch(val_loader, "val") if val_loader else None - def train[B]( + aux_loaders = aux_loaders or [] + aux_losses = [ + self._eval_epoch(aux_loader, f"aux{i}") + for i, aux_loader in enumerate(aux_loaders) + ] + + return train_loss, val_loss, *aux_losses + + def train( self, - dataset: BatchedDataset[..., ..., I], - batch_estimator_map: Callable[[B, Self], K], + train_loader: EstimatorDataLoader[Any, Kw], + val_loader: EstimatorDataLoader[Any, Kw] | None = None, + aux_loaders: list[EstimatorDataLoader[Any, Kw]] | None = None, + *, lr: float = 1e-3, eps: float = 1e-8, max_grad_norm: float | None = None, max_epochs: int = 10, stop_after_epochs: int = 5, - batch_size: int = 256, - val_frac: float = 0.1, - train_transform: Transform | None = None, - val_transform: Transform | None = None, - dataset_split_kwargs: SplitKwargs | None = None, - dataset_balance_kwargs: BalanceKwargs | None = None, - dataloader_kwargs: LoaderKwargs | None = None, summarize_every: int = 1, chkpt_every: int = 1, - resume_latest: bool = False, session_name: str | None = None, summary_writer: SummaryWriter | None = None, - aux_loaders: list[DataLoader] | None = None, - aux_loader_labels: list[str] | None = None, ) -> Estimator: """ - TODO: consider making the dataloader ``collate_fn`` an explicit - parameter with a type signature that reflects ``B``, connecting the - ``batch_estimator_map`` somewhere. Might also re-type a ``DataLoader`` - in-house to allow a generic around ``B`` + .. todo:: + + - consider making the dataloader ``collate_fn`` an explicit + parameter with a type signature that reflects ``B``, connecting + the ``batch_estimator_map`` somewhere. Might also re-type a + ``DataLoader`` in-house to allow a generic around ``B`` + - Rework the validation specification. 
Accept something like a + "validate_with" parameter, or perhaps just move entirely to + accepting a dataloader list, label list. You might then also need + a "train_with," and you could set up sensible defaults so you + basically have the same interaction as now. The "problem" is you + always need a train set, and there's some clearly dependent logic + on a val set, but you don't *need* val, so this should be + slightly reworked (and the more general, *probably* the better in + this case, given I want to plug into the Plotter with possibly + several purely eval sets over the model training lifetime). Note: this method attempts to implement a general scheme for passing needed items to the estimator's loss function from the dataloader. The @@ -355,7 +339,7 @@ class Trainer[I, K: EstimatorKwargs]: This function should map from batches - which *may* be item shaped, i.e., have an ``I`` skeleton, even if stacked items may be different on the inside - into estimator keyword arguments (type - ``K``). Collation behavior from a DataLoader (which can be + ``Kw``). Collation behavior from a DataLoader (which can be customized) doesn't consistently yield a known type shape, however, so it's not appropriate to use ``I`` as the callable param type. 
@@ -416,18 +400,13 @@ class Trainer[I, K: EstimatorKwargs]: dataset val_split_frac: fraction of dataset to use for validation chkpt_every: how often model checkpoints should be saved - resume_latest: resume training from the latest available checkpoint - in the `chkpt_dir` """ logger.info("> Begin train loop:") logger.info(f"| > {lr=}") logger.info(f"| > {eps=}") logger.info(f"| > {max_epochs=}") - logger.info(f"| > {batch_size=}") - logger.info(f"| > {val_frac=}") logger.info(f"| > {chkpt_every=}") - logger.info(f"| > {resume_latest=}") logger.info(f"| > with device: {self.device}") logger.info(f"| > core count: {os.cpu_count()}") @@ -435,25 +414,9 @@ class Trainer[I, K: EstimatorKwargs]: tblog_path = Path(self.tblog_dir, self._session_name) self._writer = summary_writer or SummaryWriter(f"{tblog_path}") - aux_loaders = aux_loaders or [] - aux_loader_labels = aux_loader_labels or [] - - optimizers = self.estimator.optimizers(lr=lr, eps=eps) - train_loader, val_loader = self.get_dataloaders( - dataset, - batch_size, - val_frac=val_frac, - train_transform=train_transform, - val_transform=val_transform, - dataset_split_kwargs=dataset_split_kwargs, - dataset_balance_kwargs=dataset_balance_kwargs, - dataloader_kwargs=dataloader_kwargs, - ) - loaders = [train_loader, val_loader, *aux_loaders] - loader_labels = ["train", "val", *aux_loader_labels] - # evaluate model on dataloaders once before training starts - self._eval_loaders(loaders, batch_estimator_map, loader_labels) + self._eval_loaders(train_loader, val_loader, aux_loaders) + optimizers = self.estimator.optimizers(lr=lr, eps=eps) while self._epoch <= max_epochs and not self._converged( self._epoch, stop_after_epochs @@ -464,22 +427,14 @@ class Trainer[I, K: EstimatorKwargs]: print(f"Stagnant epochs {stag_frac}...") epoch_start_time = time.time() - self._train_epoch( - train_loader, - batch_estimator_map, - optimizers, - max_grad_norm, - ) + self._train_epoch(train_loader, optimizers, max_grad_norm) epoch_end_time 
= time.time() - epoch_start_time self._log_event("train", "epoch_duration", epoch_end_time) - loss_sum_map = self._eval_loaders( - loaders, - batch_estimator_map, - loader_labels, + train_loss, val_loss, _ = self._eval_loaders( + train_loader, val_loader, aux_loaders ) - val_loss_sums = loss_sum_map["val"] - self._val_loss = sum(val_loss_sums) / len(val_loader) + self._conv_loss = sum(val_loss) if val_loss else sum(train_loss) if self._epoch % summarize_every == 0: self._summarize() @@ -492,8 +447,8 @@ class Trainer[I, K: EstimatorKwargs]: def _converged(self, epoch: int, stop_after_epochs: int) -> bool: converged = False - if epoch == 1 or self._val_loss < self._best_val_loss: - self._best_val_loss = self._val_loss + if epoch == 1 or self._conv_loss < self._best_val_loss: + self._best_val_loss = self._conv_loss self._stagnant_epochs = 0 self._best_model_state_dict = deepcopy(self.estimator.state_dict()) else: @@ -505,110 +460,24 @@ class Trainer[I, K: EstimatorKwargs]: return converged - @staticmethod - def get_dataloaders( - dataset: BatchedDataset, - batch_size: int, - val_frac: float = 0.1, - train_transform: Transform | None = None, - val_transform: Transform | None = None, - dataset_split_kwargs: SplitKwargs | None = None, - dataset_balance_kwargs: BalanceKwargs | None = None, - dataloader_kwargs: LoaderKwargs | None = None, - ) -> tuple[DataLoader, DataLoader]: - """ - Create training and validation dataloaders for the provided dataset. - - .. 
todo:: - - Decide on policy for empty val dataloaders - """ - - if dataset_split_kwargs is None: - dataset_split_kwargs = {} - - if dataset_balance_kwargs is not None: - dataset.balance(**dataset_balance_kwargs) - - if val_frac <= 0: - dataset.post_transform = train_transform - train_loader_kwargs: LoaderKwargs = { - "batch_size": min(batch_size, len(dataset)), - "num_workers": 0, - "shuffle": True, - } - if dataloader_kwargs is not None: - train_loader_kwargs: LoaderKwargs = { - **train_loader_kwargs, - **dataloader_kwargs - } - - return ( - DataLoader(dataset, **train_loader_kwargs), - DataLoader(Dataset()) - ) - - train_dataset, val_dataset = dataset.split( - [1 - val_frac, val_frac], - **dataset_split_kwargs, - ) - - # Dataset.split() returns light Subset objects of shallow copies of the - # underlying dataset; can change the transform attribute of both splits - # w/o overwriting - train_dataset.post_transform = train_transform - val_dataset.post_transform = val_transform - - train_loader_kwargs: LoaderKwargs = { - "batch_size": min(batch_size, len(train_dataset)), - "num_workers": 0, - "shuffle": True, - } - val_loader_kwargs: LoaderKwargs = { - "batch_size": min(batch_size, len(val_dataset)), - "num_workers": 0, - "shuffle": True, # shuffle to prevent homogeneous val batches - } - - if dataloader_kwargs is not None: - train_loader_kwargs = {**train_loader_kwargs, **dataloader_kwargs} - val_loader_kwargs = {**val_loader_kwargs, **dataloader_kwargs} - - train_loader = DataLoader(train_dataset, **train_loader_kwargs) - val_loader = DataLoader(val_dataset, **val_loader_kwargs) - - return train_loader, val_loader - def _summarize(self) -> None: """ - Flush the training summary to the TensorBoard summary writer. - - .. note:: Possibly undesirable behavior - - Currently, this method aggregates metrics for the epoch summary - across all logged items *in between summarize calls*. 
For instance, - if I'm logging every 10 epochs, the stats at epoch=10 are actually - averages from epochs 1-10. + Flush the training summary to the TensorBoard summary writer and print + metrics for the current epoch. """ - epoch_values = defaultdict(lambda: defaultdict(list)) - for group, records in self._summary.items(): - for name, steps in records.items(): - for step, value in steps.items(): - self._writer.add_scalar(f"{group}-{name}", value, step) - if step == self._epoch: - epoch_values[group][name].append(value) - print(f"==== Epoch [{self._epoch}] summary ====") - for group, records in epoch_values.items(): - for name, values in records.items(): - mean_value = torch.tensor(values).mean().item() - print( - f"> ({len(values)}) [{group}] {name} :: {mean_value:.2f}" - ) + for (group, name), epoch_map in self._summary.items(): + for epoch, values in epoch_map.items(): + mean = torch.tensor(values).mean().item() + self._writer.add_scalar(f"{group}-{name}", mean, epoch) + if epoch == self._epoch: + print( + f"> ({len(values)}) [{group}] {name} :: {mean:.2f}" + ) self._writer.flush() - self._summary = defaultdict(lambda: defaultdict(dict)) + self._summary = defaultdict(lambda: defaultdict(list)) def _get_optimizer_parameters( self, @@ -622,33 +491,10 @@ class Trainer[I, K: EstimatorKwargs]: ] def _log_event(self, group: str, name: str, value: float) -> None: - self._summary[group][name][self._epoch] = value - self._event_log[self._session_name][group][name][self._epoch] = value + session, epoch = self._session_name, self._epoch - def get_batch_outputs[B]( - self, - batch: B, - batch_estimator_map: Callable[[B, Self], K], - ) -> Tensor: - self.estimator.eval() - - est_kwargs = batch_estimator_map(batch, self) - output = self.estimator(**est_kwargs)[0] - output = output.detach().cpu() - - return output - - def get_batch_metrics[B]( - self, - batch: B, - batch_estimator_map: Callable[[B, Self], K], - ) -> dict[str, float]: - self.estimator.eval() - - est_kwargs = 
def nested_defaultdict(
    depth: int,
    final: type = dict,
) -> defaultdict:
    """
    Build a ``defaultdict`` whose missing-key chain auto-creates ``depth``
    levels of nesting, with leaves produced by ``final``.

    For example, ``nested_defaultdict(2, list)[a][b]`` yields a fresh
    ``list`` without any explicit insertion.

    Parameters:
        depth: number of nested mapping levels; must be >= 1
        final: zero-argument factory for the innermost values

    Raises:
        ValueError: if ``depth`` is less than 1

    .. note::

        The original silently accepted ``depth < 1``, producing a structure
        whose leaves were never ``final`` instances; this now fails fast.
    """
    if depth < 1:
        raise ValueError(f"depth must be >= 1, got {depth}")
    if depth == 1:
        return defaultdict(final)
    return defaultdict(lambda: nested_defaultdict(depth - 1, final))