fix trainer epoch logging
This commit is contained in:
11
TODO.md
Normal file
11
TODO.md
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
# Long-term
|
||||||
|
- Implement a dataloader in-house, with a clear, lightweight mechanism for
|
||||||
|
collection-of-structures to structure-of-collections. For multi-proc handling
|
||||||
|
(happens in torch's dataloader, as well as the BatchedDataset for two
|
||||||
|
different purposes), we should rely on (a hopefully more stable) `execlib`.
|
||||||
|
- `Domains` may be externalized (`co3` or `convlib`)
|
||||||
|
- Up next: CLI, full JSON-ification of model selection + train.
|
||||||
|
- Consider a "multi-train" alternative (or arg support in `train()`) for
|
||||||
|
training many "rollouts" from the same base estimator (basically forks under
|
||||||
|
different seeds). For architecture benchmarking above all, seeing average
|
||||||
|
training behavior. Consider corresponding `Plotter` methods (error bars)
|
||||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "trainlib"
|
name = "trainlib"
|
||||||
version = "0.1.2"
|
version = "0.2.0"
|
||||||
description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training."
|
description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training."
|
||||||
requires-python = ">=3.13"
|
requires-python = ">=3.13"
|
||||||
authors = [
|
authors = [
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ class MLP[Kw: MLPKwargs](Estimator[Kw]):
|
|||||||
mae = F.l1_loss(predictions, labels).item()
|
mae = F.l1_loss(predictions, labels).item()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"mse": loss,
|
# "mse": loss,
|
||||||
"mae": mae,
|
"mae": mae,
|
||||||
"grad_norm": get_grad_norm(self)
|
"grad_norm": get_grad_norm(self)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ class Plotter[Kw: EstimatorKwargs]:
|
|||||||
intervals broken over the training epochs at 0, 50, 100, 150, ... and
|
intervals broken over the training epochs at 0, 50, 100, 150, ... and
|
||||||
highlight the best one, even if that's not actually the single best
|
highlight the best one, even if that's not actually the single best
|
||||||
epoch)
|
epoch)
|
||||||
|
- Implement data and dimension limits, in case dataloaders have
|
||||||
|
huge numbers of samples or labels are high-dimensional
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -255,6 +257,12 @@ class Plotter[Kw: EstimatorKwargs]:
|
|||||||
|
|
||||||
return fig, axes
|
return fig, axes
|
||||||
|
|
||||||
|
# def plot_ordered(...): ...
|
||||||
|
# """
|
||||||
|
# Simple ordered view of output dimensions, with actual and output
|
||||||
|
# overlaid.
|
||||||
|
# """
|
||||||
|
|
||||||
def plot_actual_output(
|
def plot_actual_output(
|
||||||
self,
|
self,
|
||||||
row_size: int | float = 2,
|
row_size: int | float = 2,
|
||||||
@@ -457,12 +465,12 @@ class Plotter[Kw: EstimatorKwargs]:
|
|||||||
combine_metrics: bool = False,
|
combine_metrics: bool = False,
|
||||||
transpose_layout: bool = False,
|
transpose_layout: bool = False,
|
||||||
figure_kwargs: SubplotsKwargs | None = None,
|
figure_kwargs: SubplotsKwargs | None = None,
|
||||||
):
|
) -> tuple[plt.Figure, AxesArray]:
|
||||||
session_map = self.trainer._event_log
|
session_map = self.trainer._event_log
|
||||||
session_name = session_name or next(iter(session_map))
|
session_name = session_name or next(iter(session_map))
|
||||||
groups = session_map[session_name]
|
groups = session_map[session_name]
|
||||||
num_metrics = len(groups[next(iter(groups))])
|
num_metrics = len(groups[next(iter(groups))])
|
||||||
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
|
# colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
|
||||||
|
|
||||||
rows = 1 if combine_groups else len(groups)
|
rows = 1 if combine_groups else len(groups)
|
||||||
cols = 1 if combine_metrics else num_metrics
|
cols = 1 if combine_metrics else num_metrics
|
||||||
@@ -513,6 +521,8 @@ class Plotter[Kw: EstimatorKwargs]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
ax.set_title(f"[{title_prefix}] Metrics over epochs")
|
ax.set_title(f"[{title_prefix}] Metrics over epochs")
|
||||||
ax.set_xlabel("epoch", fontstyle='italic')
|
ax.set_xlabel("epoch")
|
||||||
ax.set_ylabel("value", fontstyle='italic')
|
ax.set_ylabel("value")
|
||||||
ax.legend()
|
ax.legend()
|
||||||
|
|
||||||
|
return fig, axes
|
||||||
|
|||||||
@@ -108,7 +108,7 @@ class Trainer[I, Kw: EstimatorKwargs]:
|
|||||||
Set initial tracking parameters for the primary training loop.
|
Set initial tracking parameters for the primary training loop.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._epoch: int = 1
|
self._epoch: int = 0
|
||||||
self._summary = defaultdict(lambda: defaultdict(list))
|
self._summary = defaultdict(lambda: defaultdict(list))
|
||||||
|
|
||||||
self._conv_loss = float("inf")
|
self._conv_loss = float("inf")
|
||||||
@@ -231,7 +231,9 @@ class Trainer[I, Kw: EstimatorKwargs]:
|
|||||||
for metric_name, metric_value in estimator_metrics.items():
|
for metric_name, metric_value in estimator_metrics.items():
|
||||||
self._log_event(label, metric_name, metric_value)
|
self._log_event(label, metric_name, metric_value)
|
||||||
|
|
||||||
return loss_sums
|
avg_losses = [loss_sum / (i+1) for loss_sum in loss_sums]
|
||||||
|
|
||||||
|
return avg_losses
|
||||||
|
|
||||||
def _eval_loaders(
|
def _eval_loaders(
|
||||||
self,
|
self,
|
||||||
@@ -252,6 +254,24 @@ class Trainer[I, Kw: EstimatorKwargs]:
|
|||||||
``get_batch_outputs()`` or ``get_batch_metrics()`` while iterating over
|
``get_batch_outputs()`` or ``get_batch_metrics()`` while iterating over
|
||||||
batches. This will have no internal side effects and provides much more
|
batches. This will have no internal side effects and provides much more
|
||||||
information (just aggregated losses are provided here).
|
information (just aggregated losses are provided here).
|
||||||
|
|
||||||
|
.. admonition:: On epoch counting
|
||||||
|
|
||||||
|
Epoch counts start at 0 to allow for a sensible place to benchmark
|
||||||
|
the initial (potentially untrained/pre-trained) model before any
|
||||||
|
training data is seen. In the train loop, we increment the epoch
|
||||||
|
immediately, and all logging happens under the epoch value that's
|
||||||
|
set at the start of the iteration (rather than incrementing at the
|
||||||
|
end). Before beginning an additional training iteration, the
|
||||||
|
convergence condition in the ``while`` is effectively checking what
|
||||||
|
happened during the last epoch (the counter has not yet been
|
||||||
|
incremented); if no convergence, we begin again. (This is only
|
||||||
|
being noted because the epoch counting was previously quite
|
||||||
|
different: indexing started at ``1``, we incremented at the end of
|
||||||
|
the loop, and we didn't evaluate the model before the loop began.
|
||||||
|
This affects how we interpret plots and TensorBoard records, for
|
||||||
|
instance, so it's useful to spell out the approach clearly
|
||||||
|
somewhere given the many possible design choices here.)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
train_loss = self._eval_epoch(train_loader, "train")
|
train_loss = self._eval_epoch(train_loader, "train")
|
||||||
@@ -418,9 +438,10 @@ class Trainer[I, Kw: EstimatorKwargs]:
|
|||||||
self._eval_loaders(train_loader, val_loader, aux_loaders)
|
self._eval_loaders(train_loader, val_loader, aux_loaders)
|
||||||
optimizers = self.estimator.optimizers(lr=lr, eps=eps)
|
optimizers = self.estimator.optimizers(lr=lr, eps=eps)
|
||||||
|
|
||||||
while self._epoch <= max_epochs and not self._converged(
|
while self._epoch < max_epochs and not self._converged(
|
||||||
self._epoch, stop_after_epochs
|
self._epoch, stop_after_epochs
|
||||||
):
|
):
|
||||||
|
self._epoch += 1
|
||||||
train_frac = f"{self._epoch}/{max_epochs}"
|
train_frac = f"{self._epoch}/{max_epochs}"
|
||||||
stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}"
|
stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}"
|
||||||
print(f"Training epoch {train_frac}...")
|
print(f"Training epoch {train_frac}...")
|
||||||
@@ -434,20 +455,21 @@ class Trainer[I, Kw: EstimatorKwargs]:
|
|||||||
train_loss, val_loss, _ = self._eval_loaders(
|
train_loss, val_loss, _ = self._eval_loaders(
|
||||||
train_loader, val_loader, aux_loaders
|
train_loader, val_loader, aux_loaders
|
||||||
)
|
)
|
||||||
self._conv_loss = sum(val_loss) if val_loss else sum(train_loss)
|
# determine loss to use for measuring convergence
|
||||||
|
conv_loss = val_loss if val_loss else train_loss
|
||||||
|
self._conv_loss = sum(conv_loss) / len(conv_loss)
|
||||||
|
|
||||||
if self._epoch % summarize_every == 0:
|
if self._epoch % summarize_every == 0:
|
||||||
self._summarize()
|
self._summarize()
|
||||||
if self._epoch % chkpt_every == 0:
|
if self._epoch % chkpt_every == 0:
|
||||||
self.save_model()
|
self.save_model()
|
||||||
self._epoch += 1
|
|
||||||
|
|
||||||
return self.estimator
|
return self.estimator
|
||||||
|
|
||||||
def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
|
def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
|
||||||
converged = False
|
converged = False
|
||||||
|
|
||||||
if epoch == 1 or self._conv_loss < self._best_val_loss:
|
if epoch == 0 or self._conv_loss < self._best_val_loss:
|
||||||
self._best_val_loss = self._conv_loss
|
self._best_val_loss = self._conv_loss
|
||||||
self._stagnant_epochs = 0
|
self._stagnant_epochs = 0
|
||||||
self._best_model_state_dict = deepcopy(self.estimator.state_dict())
|
self._best_model_state_dict = deepcopy(self.estimator.state_dict())
|
||||||
|
|||||||
Reference in New Issue
Block a user