fix trainer epoch logging

This commit is contained in:
2026-03-22 20:24:10 -07:00
parent a395a08d5c
commit b59749c8d8
5 changed files with 55 additions and 12 deletions

11
TODO.md Normal file
View File

@@ -0,0 +1,11 @@
# Long-term
- Implement a dataloader in-house, with a clear, lightweight mechanism for
converting collection-of-structures to structure-of-collections. For multi-proc
handling (which happens both in torch's dataloader and in the BatchedDataset,
for two different purposes), we should rely on (a hopefully more stable) `execlib`.
- `Domains` may be externalized (`co3` or `convlib`)
- Up next: CLI, full JSON-ification of model selection + training.
- Consider a "multi-train" alternative (or arg support in `train()`) for
training many "rollouts" from the same base estimator (basically forks under
different seeds). For architecture benchmarking above all, seeing average
training behavior. Consider corresponding `Plotter` methods (error bars)

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "trainlib"
version = "0.1.2"
version = "0.2.0"
description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training."
requires-python = ">=3.13"
authors = [

View File

@@ -103,7 +103,7 @@ class MLP[Kw: MLPKwargs](Estimator[Kw]):
mae = F.l1_loss(predictions, labels).item()
return {
"mse": loss,
# "mse": loss,
"mae": mae,
"grad_norm": get_grad_norm(self)
}

View File

@@ -26,6 +26,8 @@ class Plotter[Kw: EstimatorKwargs]:
intervals broken over the training epochs at 0, 50, 100, 150, ... and
highlight the best one, even if that's not actually the single best
epoch)
- Implement data and dimension limits, for the case where dataloaders have
huge numbers of samples or labels are high-dimensional
"""
def __init__(
@@ -255,6 +257,12 @@ class Plotter[Kw: EstimatorKwargs]:
return fig, axes
# def plot_ordered(...): ...
# """
# Simple ordered view of output dimensions, with actual and output
# overlaid.
# """
def plot_actual_output(
self,
row_size: int | float = 2,
@@ -457,12 +465,12 @@ class Plotter[Kw: EstimatorKwargs]:
combine_metrics: bool = False,
transpose_layout: bool = False,
figure_kwargs: SubplotsKwargs | None = None,
):
) -> tuple[plt.Figure, AxesArray]:
session_map = self.trainer._event_log
session_name = session_name or next(iter(session_map))
groups = session_map[session_name]
num_metrics = len(groups[next(iter(groups))])
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
# colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
rows = 1 if combine_groups else len(groups)
cols = 1 if combine_metrics else num_metrics
@@ -513,6 +521,8 @@ class Plotter[Kw: EstimatorKwargs]:
)
ax.set_title(f"[{title_prefix}] Metrics over epochs")
ax.set_xlabel("epoch", fontstyle='italic')
ax.set_ylabel("value", fontstyle='italic')
ax.set_xlabel("epoch")
ax.set_ylabel("value")
ax.legend()
return fig, axes

View File

@@ -108,7 +108,7 @@ class Trainer[I, Kw: EstimatorKwargs]:
Set initial tracking parameters for the primary training loop.
"""
self._epoch: int = 1
self._epoch: int = 0
self._summary = defaultdict(lambda: defaultdict(list))
self._conv_loss = float("inf")
@@ -231,7 +231,9 @@ class Trainer[I, Kw: EstimatorKwargs]:
for metric_name, metric_value in estimator_metrics.items():
self._log_event(label, metric_name, metric_value)
return loss_sums
avg_losses = [loss_sum / (i+1) for loss_sum in loss_sums]
return avg_losses
def _eval_loaders(
self,
@@ -252,6 +254,24 @@ class Trainer[I, Kw: EstimatorKwargs]:
``get_batch_outputs()`` or ``get_batch_metrics()`` while iterating over
batches. This will have no internal side effects and provides much more
information (just aggregated losses are provided here).
.. admonition:: On epoch counting
Epoch counts start at 0 to allow for a sensible place to benchmark
the initial (potentially untrained/pre-trained) model before any
training data is seen. In the train loop, we increment the epoch
immediately, and all logging happens under the epoch value that's
set at the start of the iteration (rather than incrementing at the
end). Before beginning an additional training iteration, the
convergence condition in the ``while`` is effectively checking what
happened during the last epoch (the counter has not yet been
incremented); if no convergence, we begin again. (This is only
being noted because the epoch counting was previously quite
different: indexing started at ``1``, we incremented at the end of
the loop, and we didn't evaluate the model before the loop began.
This affects how we interpret plots and TensorBoard records, for
instance, so it's useful to spell out the approach clearly
somewhere given the many possible design choices here.)
"""
train_loss = self._eval_epoch(train_loader, "train")
@@ -418,9 +438,10 @@ class Trainer[I, Kw: EstimatorKwargs]:
self._eval_loaders(train_loader, val_loader, aux_loaders)
optimizers = self.estimator.optimizers(lr=lr, eps=eps)
while self._epoch <= max_epochs and not self._converged(
while self._epoch < max_epochs and not self._converged(
self._epoch, stop_after_epochs
):
self._epoch += 1
train_frac = f"{self._epoch}/{max_epochs}"
stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}"
print(f"Training epoch {train_frac}...")
@@ -434,20 +455,21 @@ class Trainer[I, Kw: EstimatorKwargs]:
train_loss, val_loss, _ = self._eval_loaders(
train_loader, val_loader, aux_loaders
)
self._conv_loss = sum(val_loss) if val_loss else sum(train_loss)
# determine loss to use for measuring convergence
conv_loss = val_loss if val_loss else train_loss
self._conv_loss = sum(conv_loss) / len(conv_loss)
if self._epoch % summarize_every == 0:
self._summarize()
if self._epoch % chkpt_every == 0:
self.save_model()
self._epoch += 1
return self.estimator
def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
converged = False
if epoch == 1 or self._conv_loss < self._best_val_loss:
if epoch == 0 or self._conv_loss < self._best_val_loss:
self._best_val_loss = self._conv_loss
self._stagnant_epochs = 0
self._best_model_state_dict = deepcopy(self.estimator.state_dict())