From b59749c8d88836c1484444575eeb000d9057aa2b Mon Sep 17 00:00:00 2001 From: smgr Date: Sun, 22 Mar 2026 20:24:10 -0700 Subject: [PATCH] fix trainer epoch logging --- TODO.md | 11 +++++++++++ pyproject.toml | 2 +- trainlib/estimators/mlp.py | 2 +- trainlib/plotter.py | 18 ++++++++++++++---- trainlib/trainer.py | 34 ++++++++++++++++++++++++++++------ 5 files changed, 55 insertions(+), 12 deletions(-) create mode 100644 TODO.md diff --git a/TODO.md b/TODO.md new file mode 100644 index 0000000..3c87a62 --- /dev/null +++ b/TODO.md @@ -0,0 +1,11 @@ +# Long-term +- Implement a dataloader in-house, with a clear, lightweight mechanism for + collection-of-structures to structure-of-collections. For multi-proc handling + (happens in torch's dataloader, as well as the BatchedDataset for two + different purposes), we should rely on (a hopefully more stable) `execlib`. +- `Domains` may be externalized (`co3` or `convlib`) +- Up next: CLI, fully JSON-ification of model selection + train. +- Consider a "multi-train" alternative (or arg support in `train()`) for + training many "rollouts" from the same base estimator (basically forks under + different seeds). For architecture benchmarking above all, seeing average + training behavior. Consider corresponding `Plotter` methods (error bars) diff --git a/pyproject.toml b/pyproject.toml index 3aa260f..9532f3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "trainlib" -version = "0.1.2" +version = "0.2.0" description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training." 
requires-python = ">=3.13" authors = [ diff --git a/trainlib/estimators/mlp.py b/trainlib/estimators/mlp.py index d041dde..a7516bc 100644 --- a/trainlib/estimators/mlp.py +++ b/trainlib/estimators/mlp.py @@ -103,7 +103,7 @@ class MLP[Kw: MLPKwargs](Estimator[Kw]): mae = F.l1_loss(predictions, labels).item() return { - "mse": loss, + # "mse": loss, "mae": mae, "grad_norm": get_grad_norm(self) } diff --git a/trainlib/plotter.py b/trainlib/plotter.py index f696606..4fd9053 100644 --- a/trainlib/plotter.py +++ b/trainlib/plotter.py @@ -26,6 +26,8 @@ class Plotter[Kw: EstimatorKwargs]: intervals broken over the training epochs at 0, 50, 100, 150, ... and highlight the best one, even if that's not actually the single best epoch) + - Implement data and dimension limits; in the instance dataloaders have + huge numbers of samples or labels are high-dimensional """ def __init__( @@ -255,6 +257,12 @@ class Plotter[Kw: EstimatorKwargs]: return fig, axes + # def plot_ordered(...): ... + # """ + # Simple ordered view of output dimensions, with actual and output + # overlaid. 
+ # """ + def plot_actual_output( self, row_size: int | float = 2, @@ -457,12 +465,12 @@ class Plotter[Kw: EstimatorKwargs]: combine_metrics: bool = False, transpose_layout: bool = False, figure_kwargs: SubplotsKwargs | None = None, - ): + ) -> tuple[plt.Figure, AxesArray]: session_map = self.trainer._event_log session_name = session_name or next(iter(session_map)) groups = session_map[session_name] num_metrics = len(groups[next(iter(groups))]) - colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] + # colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] rows = 1 if combine_groups else len(groups) cols = 1 if combine_metrics else num_metrics @@ -513,6 +521,8 @@ class Plotter[Kw: EstimatorKwargs]: ) ax.set_title(f"[{title_prefix}] Metrics over epochs") - ax.set_xlabel("epoch", fontstyle='italic') - ax.set_ylabel("value", fontstyle='italic') + ax.set_xlabel("epoch") + ax.set_ylabel("value") ax.legend() + + return fig, axes diff --git a/trainlib/trainer.py b/trainlib/trainer.py index 0c00eeb..cb0fd67 100644 --- a/trainlib/trainer.py +++ b/trainlib/trainer.py @@ -108,7 +108,7 @@ class Trainer[I, Kw: EstimatorKwargs]: Set initial tracking parameters for the primary training loop. """ - self._epoch: int = 1 + self._epoch: int = 0 self._summary = defaultdict(lambda: defaultdict(list)) self._conv_loss = float("inf") @@ -231,7 +231,9 @@ class Trainer[I, Kw: EstimatorKwargs]: for metric_name, metric_value in estimator_metrics.items(): self._log_event(label, metric_name, metric_value) - return loss_sums + avg_losses = [loss_sum / (i+1) for loss_sum in loss_sums] + + return avg_losses def _eval_loaders( self, @@ -252,6 +254,24 @@ class Trainer[I, Kw: EstimatorKwargs]: ``get_batch_outputs()`` or ``get_batch_metrics()`` while iterating over batches. This will have no internal side effects and provides much more information (just aggregated losses are provided here). + + .. 
admonition:: On epoch counting + + Epoch counts start at 0 to allow for a sensible place to benchmark + the initial (potentially untrained/pre-trained) model before any + training data is seen. In the train loop, we increment the epoch + immediately, and all logging happens under the epoch value that's + set at the start of the iteration (rather than incrementing at the + end). Before beginning an additional training iteration, the + convergence condition in the ``while`` is effectively checking what + happened during the last epoch (the counter has not yet been + incremented); if no convergence, we begin again. (This is only + being noted because the epoch counting was previously quite + different: indexing started at ``1``, we incremented at the end of + the loop, and we didn't evaluate the model before the loop began. + This affects how we interpret plots and TensorBoard records, for + instance, so it's useful to spell out the approach clearly + somewhere given the many possible design choices here.) 
""" train_loss = self._eval_epoch(train_loader, "train") @@ -418,9 +438,10 @@ class Trainer[I, Kw: EstimatorKwargs]: self._eval_loaders(train_loader, val_loader, aux_loaders) optimizers = self.estimator.optimizers(lr=lr, eps=eps) - while self._epoch <= max_epochs and not self._converged( + while self._epoch < max_epochs and not self._converged( self._epoch, stop_after_epochs ): + self._epoch += 1 train_frac = f"{self._epoch}/{max_epochs}" stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}" print(f"Training epoch {train_frac}...") @@ -434,20 +455,21 @@ class Trainer[I, Kw: EstimatorKwargs]: train_loss, val_loss, _ = self._eval_loaders( train_loader, val_loader, aux_loaders ) - self._conv_loss = sum(val_loss) if val_loss else sum(train_loss) + # determine loss to use for measuring convergence + conv_loss = val_loss if val_loss else train_loss + self._conv_loss = sum(conv_loss) / len(conv_loss) if self._epoch % summarize_every == 0: self._summarize() if self._epoch % chkpt_every == 0: self.save_model() - self._epoch += 1 return self.estimator def _converged(self, epoch: int, stop_after_epochs: int) -> bool: converged = False - if epoch == 1 or self._conv_loss < self._best_val_loss: + if epoch == 0 or self._conv_loss < self._best_val_loss: self._best_val_loss = self._conv_loss self._stagnant_epochs = 0 self._best_model_state_dict = deepcopy(self.estimator.state_dict())