fix trainer epoch logging
This commit is contained in:
11
TODO.md
Normal file
11
TODO.md
Normal file
@@ -0,0 +1,11 @@
|
||||
# Long-term
|
||||
- Implement a dataloader in-house, with a clear, lightweight mechanism for converting
|
||||
collection-of-structures to structure-of-collections. For multi-proc handling
|
||||
(happens in torch's dataloader, as well as the BatchedDataset for two
|
||||
different purposes), we should rely on (a hopefully more stable) `execlib`.
|
||||
- `Domains` may be externalized (`co3` or `convlib`)
|
||||
- Up next: CLI, full JSON-ification of model selection + train.
|
||||
- Consider a "multi-train" alternative (or arg support in `train()`) for
|
||||
training many "rollouts" from the same base estimator (basically forks under
|
||||
different seeds). For architecture benchmarking above all, seeing average
|
||||
training behavior. Consider corresponding `Plotter` methods (error bars)
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "trainlib"
|
||||
version = "0.1.2"
|
||||
version = "0.2.0"
|
||||
description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training."
|
||||
requires-python = ">=3.13"
|
||||
authors = [
|
||||
|
||||
@@ -103,7 +103,7 @@ class MLP[Kw: MLPKwargs](Estimator[Kw]):
|
||||
mae = F.l1_loss(predictions, labels).item()
|
||||
|
||||
return {
|
||||
"mse": loss,
|
||||
# "mse": loss,
|
||||
"mae": mae,
|
||||
"grad_norm": get_grad_norm(self)
|
||||
}
|
||||
|
||||
@@ -26,6 +26,8 @@ class Plotter[Kw: EstimatorKwargs]:
|
||||
intervals broken over the training epochs at 0, 50, 100, 150, ... and
|
||||
highlight the best one, even if that's not actually the single best
|
||||
epoch)
|
||||
- Implement data and dimension limits, for cases where dataloaders have
|
||||
huge numbers of samples or labels are high-dimensional
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -255,6 +257,12 @@ class Plotter[Kw: EstimatorKwargs]:
|
||||
|
||||
return fig, axes
|
||||
|
||||
# def plot_ordered(...): ...
|
||||
# """
|
||||
# Simple ordered view of output dimensions, with actual and output
|
||||
# overlaid.
|
||||
# """
|
||||
|
||||
def plot_actual_output(
|
||||
self,
|
||||
row_size: int | float = 2,
|
||||
@@ -457,12 +465,12 @@ class Plotter[Kw: EstimatorKwargs]:
|
||||
combine_metrics: bool = False,
|
||||
transpose_layout: bool = False,
|
||||
figure_kwargs: SubplotsKwargs | None = None,
|
||||
):
|
||||
) -> tuple[plt.Figure, AxesArray]:
|
||||
session_map = self.trainer._event_log
|
||||
session_name = session_name or next(iter(session_map))
|
||||
groups = session_map[session_name]
|
||||
num_metrics = len(groups[next(iter(groups))])
|
||||
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
|
||||
# colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
|
||||
|
||||
rows = 1 if combine_groups else len(groups)
|
||||
cols = 1 if combine_metrics else num_metrics
|
||||
@@ -513,6 +521,8 @@ class Plotter[Kw: EstimatorKwargs]:
|
||||
)
|
||||
|
||||
ax.set_title(f"[{title_prefix}] Metrics over epochs")
|
||||
ax.set_xlabel("epoch", fontstyle='italic')
|
||||
ax.set_ylabel("value", fontstyle='italic')
|
||||
ax.set_xlabel("epoch")
|
||||
ax.set_ylabel("value")
|
||||
ax.legend()
|
||||
|
||||
return fig, axes
|
||||
|
||||
@@ -108,7 +108,7 @@ class Trainer[I, Kw: EstimatorKwargs]:
|
||||
Set initial tracking parameters for the primary training loop.
|
||||
"""
|
||||
|
||||
self._epoch: int = 1
|
||||
self._epoch: int = 0
|
||||
self._summary = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
self._conv_loss = float("inf")
|
||||
@@ -231,7 +231,9 @@ class Trainer[I, Kw: EstimatorKwargs]:
|
||||
for metric_name, metric_value in estimator_metrics.items():
|
||||
self._log_event(label, metric_name, metric_value)
|
||||
|
||||
return loss_sums
|
||||
avg_losses = [loss_sum / (i+1) for loss_sum in loss_sums]
|
||||
|
||||
return avg_losses
|
||||
|
||||
def _eval_loaders(
|
||||
self,
|
||||
@@ -252,6 +254,24 @@ class Trainer[I, Kw: EstimatorKwargs]:
|
||||
``get_batch_outputs()`` or ``get_batch_metrics()`` while iterating over
|
||||
batches. This will have no internal side effects and provides much more
|
||||
information (just aggregated losses are provided here).
|
||||
|
||||
.. admonition:: On epoch counting
|
||||
|
||||
Epoch counts start at 0 to allow for a sensible place to benchmark
|
||||
the initial (potentially untrained/pre-trained) model before any
|
||||
training data is seen. In the train loop, we increment the epoch
|
||||
immediately, and all logging happens under the epoch value that's
|
||||
set at the start of the iteration (rather than incrementing at the
|
||||
end). Before beginning an additional training iteration, the
|
||||
convergence condition in the ``while`` is effectively checking what
|
||||
happened during the last epoch (the counter has not yet been
|
||||
incremented); if no convergence, we begin again. (This is only
|
||||
being noted because the epoch counting was previously quite
|
||||
different: indexing started at ``1``, we incremented at the end of
|
||||
the loop, and we didn't evaluate the model before the loop began.
|
||||
This affects how we interpret plots and TensorBoard records, for
|
||||
instance, so it's useful to spell out the approach clearly
|
||||
somewhere given the many possible design choices here.)
|
||||
"""
|
||||
|
||||
train_loss = self._eval_epoch(train_loader, "train")
|
||||
@@ -418,9 +438,10 @@ class Trainer[I, Kw: EstimatorKwargs]:
|
||||
self._eval_loaders(train_loader, val_loader, aux_loaders)
|
||||
optimizers = self.estimator.optimizers(lr=lr, eps=eps)
|
||||
|
||||
while self._epoch <= max_epochs and not self._converged(
|
||||
while self._epoch < max_epochs and not self._converged(
|
||||
self._epoch, stop_after_epochs
|
||||
):
|
||||
self._epoch += 1
|
||||
train_frac = f"{self._epoch}/{max_epochs}"
|
||||
stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}"
|
||||
print(f"Training epoch {train_frac}...")
|
||||
@@ -434,20 +455,21 @@ class Trainer[I, Kw: EstimatorKwargs]:
|
||||
train_loss, val_loss, _ = self._eval_loaders(
|
||||
train_loader, val_loader, aux_loaders
|
||||
)
|
||||
self._conv_loss = sum(val_loss) if val_loss else sum(train_loss)
|
||||
# determine loss to use for measuring convergence
|
||||
conv_loss = val_loss if val_loss else train_loss
|
||||
self._conv_loss = sum(conv_loss) / len(conv_loss)
|
||||
|
||||
if self._epoch % summarize_every == 0:
|
||||
self._summarize()
|
||||
if self._epoch % chkpt_every == 0:
|
||||
self.save_model()
|
||||
self._epoch += 1
|
||||
|
||||
return self.estimator
|
||||
|
||||
def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
|
||||
converged = False
|
||||
|
||||
if epoch == 1 or self._conv_loss < self._best_val_loss:
|
||||
if epoch == 0 or self._conv_loss < self._best_val_loss:
|
||||
self._best_val_loss = self._conv_loss
|
||||
self._stagnant_epochs = 0
|
||||
self._best_model_state_dict = deepcopy(self.estimator.state_dict())
|
||||
|
||||
Reference in New Issue
Block a user