update train loop eval logic

This commit is contained in:
2026-03-31 22:52:27 -07:00
parent ba0c804d5e
commit fdccb4c5eb
7 changed files with 116 additions and 29 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "trainlib" name = "trainlib"
version = "0.3.0" version = "0.3.1"
description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training." description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training."
requires-python = ">=3.13" requires-python = ">=3.13"
authors = [ authors = [

View File

@@ -108,12 +108,26 @@ class SlidingWindowDataset(TupleDataset[Tensor]):
offset: int = 0, offset: int = 0,
lookahead: int = 1, lookahead: int = 1,
num_windows: int = 1, num_windows: int = 1,
pad_mode: str = "constant",
#fill_with: str = "zero",
**kwargs: Unpack[DatasetKwargs], **kwargs: Unpack[DatasetKwargs],
) -> None: ) -> None:
"""
Parameters:
        TODO: implement options for `fill_with`; currently just passing
        a `pad_mode` through to the functional call, which does the job
fill_with: strategy to use for padding values in windows
- `zero`: fill with zeros
- `left`: use nearest window column (repeat leftmost)
- `mean`: fill with the window mean
"""
self.lookback = lookback self.lookback = lookback
self.offset = offset self.offset = offset
self.lookahead = lookahead self.lookahead = lookahead
self.num_windows = num_windows self.num_windows = num_windows
self.pad_mode = pad_mode
super().__init__(domain, **kwargs) super().__init__(domain, **kwargs)
@@ -123,7 +137,7 @@ class SlidingWindowDataset(TupleDataset[Tensor]):
batch_index: int, batch_index: int,
) -> list[tuple[Tensor, ...]]: ) -> list[tuple[Tensor, ...]]:
""" """
Backward pads first sequence over (lookback-1) length, and steps the Backward pads window sequences over (lookback-1) length, and steps the
remaining items forward by the lookahead. remaining items forward by the lookahead.
Batch data: Batch data:
@@ -166,7 +180,7 @@ class SlidingWindowDataset(TupleDataset[Tensor]):
exceeds the offset. exceeds the offset.
To get windows starting with the first index at the left: we first set To get windows starting with the first index at the left: we first set
out window size (call it L), determined by `lookback`. Then the our window size (call it L), determined by `lookback`. Then the
rightmost index we want will be `L-1`, which determines our `offset` rightmost index we want will be `L-1`, which determines our `offset`
setting. setting.
@@ -193,7 +207,7 @@ class SlidingWindowDataset(TupleDataset[Tensor]):
# for window sized `lb`, we pad with `lb-1` zeros. We then take off # for window sized `lb`, we pad with `lb-1` zeros. We then take off
# the amount of our offset, which in the extreme cases does no # the amount of our offset, which in the extreme cases does no
# padding. # padding.
xip = F.pad(t, ((lb-1) - off, 0)) xip = F.pad(t, ((lb-1) - off, 0), mode=self.pad_mode)
# extract sliding windows over the padded tensor # extract sliding windows over the padded tensor
# unfold(-1, lb, 1) slides over the last dim, 1 step at a time, for # unfold(-1, lb, 1) slides over the last dim, 1 step at a time, for

View File

@@ -8,6 +8,7 @@ from torch import nn, Tensor
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from trainlib.utils import op
from trainlib.estimator import Estimator, EstimatorKwargs from trainlib.estimator import Estimator, EstimatorKwargs
from trainlib.utils.type import OptimizerKwargs from trainlib.utils.type import OptimizerKwargs
from trainlib.utils.module import get_grad_norm from trainlib.utils.module import get_grad_norm
@@ -102,6 +103,7 @@ class LSTM[Kw: RNNKwargs](Estimator[Kw]):
labels = kwargs["labels"] labels = kwargs["labels"]
yield F.mse_loss(predictions, labels) yield F.mse_loss(predictions, labels)
#yield F.l1_loss(predictions, labels)
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]: def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
with torch.no_grad(): with torch.no_grad():
@@ -109,12 +111,16 @@ class LSTM[Kw: RNNKwargs](Estimator[Kw]):
predictions = self(**kwargs)[0] predictions = self(**kwargs)[0]
labels = kwargs["labels"] labels = kwargs["labels"]
mse = F.mse_loss(predictions, labels).item()
mae = F.l1_loss(predictions, labels).item() mae = F.l1_loss(predictions, labels).item()
r2 = op.r2_score(predictions, labels).item()
return { return {
# "loss": loss, "loss": loss,
"mse": loss, "mse": mse,
"mae": mae, "mae": mae,
"r2": r2,
"grad_norm": get_grad_norm(self) "grad_norm": get_grad_norm(self)
} }
@@ -377,7 +383,8 @@ class ConvGRU[Kw: RNNKwargs](Estimator[Kw]):
# will be (B, T, C), applies indep at each time step across channels # will be (B, T, C), applies indep at each time step across channels
# self.dense_z = nn.Linear(layer_in_dim, self.output_dim) # self.dense_z = nn.Linear(layer_in_dim, self.output_dim)
# will be (B, C, T), applies indep at each time step across channels # will be (B, Co, 1), applies indep at each channel across temporal dim
# size time steps
self.dense_z = TDNNLayer( self.dense_z = TDNNLayer(
layer_in_dim, layer_in_dim,
self.output_dim, self.output_dim,
@@ -438,6 +445,7 @@ class ConvGRU[Kw: RNNKwargs](Estimator[Kw]):
predictions = predictions.squeeze(-1) predictions = predictions.squeeze(-1)
yield F.mse_loss(predictions, labels, reduction="mean") yield F.mse_loss(predictions, labels, reduction="mean")
#yield F.l1_loss(predictions, labels)
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]: def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
with torch.no_grad(): with torch.no_grad():
@@ -445,11 +453,16 @@ class ConvGRU[Kw: RNNKwargs](Estimator[Kw]):
predictions = self(**kwargs)[0].squeeze(-1) predictions = self(**kwargs)[0].squeeze(-1)
labels = kwargs["labels"] labels = kwargs["labels"]
mse = F.mse_loss(predictions, labels).item()
mae = F.l1_loss(predictions, labels).item() mae = F.l1_loss(predictions, labels).item()
r2 = op.r2_score(predictions, labels).item()
return { return {
"mse": loss, "loss": loss,
"mse": mse,
"mae": mae, "mae": mae,
"r2": r2,
"grad_norm": get_grad_norm(self) "grad_norm": get_grad_norm(self)
} }

View File

@@ -64,14 +64,17 @@ class Plotter[Kw: EstimatorKwargs]:
for i, loader in enumerate(self.dataloaders): for i, loader in enumerate(self.dataloaders):
label = self.dataloader_labels[i] label = self.dataloader_labels[i]
actual = torch.cat([ actual = [
self.kw_to_actual(batch_kwargs).detach().cpu() self.kw_to_actual(batch_kwargs).detach().cpu()
for batch_kwargs in loader for batch_kwargs in loader
]) ]
output = torch.cat([ actual = torch.cat([ai.reshape(*([*ai.shape]+[1])[:2]) for ai in actual])
output = [
self.trainer.estimator(**batch_kwargs)[0].detach().cpu() self.trainer.estimator(**batch_kwargs)[0].detach().cpu()
for batch_kwargs in loader for batch_kwargs in loader
]) ]
output = torch.cat([oi.reshape(*([*oi.shape]+[1])[:2]) for oi in output])
data_tuples.append((actual, output, label)) data_tuples.append((actual, output, label))

View File

@@ -104,18 +104,32 @@ class Trainer[I, Kw: EstimatorKwargs]:
self.reset() self.reset()
def reset(self) -> None: def reset(self, resume: bool = False) -> None:
""" """
Set initial tracking parameters for the primary training loop. Set initial tracking parameters for the primary training loop.
Parameters:
resume: if ``True``, just resets the stagnant epoch counter, with
the aims of continuing any existing training state under
resumed ``train()`` call. This should likely only be set when
training is continued on the same dataset and the goal is to
resume convergence loss-based scoring for a fresh set of
epochs. If even that element of the training loop should resume
(which should only happen if a training loop was interrupted or
a max epoch limit was reached), then this method shouldn't be
called at all between ``train()`` invocations.
""" """
self._epoch: int = 0
self._summary = defaultdict(lambda: defaultdict(list))
self._conv_loss = float("inf")
self._best_conv_loss = float("inf")
self._stagnant_epochs = 0 self._stagnant_epochs = 0
self._best_model_state_dict: dict[str, Any] = {}
if not resume:
self._epoch: int = 0
self._summary = defaultdict(lambda: defaultdict(list))
self._conv_loss = float("inf")
self._best_conv_loss = float("inf")
self._best_conv_epoch = 0
self._best_model_state_dict: dict[str, Any] = {}
def _train_epoch( def _train_epoch(
self, self,
@@ -459,15 +473,18 @@ class Trainer[I, Kw: EstimatorKwargs]:
progress_bar = tqdm(train_loader, unit="batch") progress_bar = tqdm(train_loader, unit="batch")
# evaluate model on dataloaders once before training starts # evaluate model on dataloaders once before training starts
self._eval_loaders(train_loader, val_loader, aux_loaders, progress_bar) train_loss, val_loss, *_ = self._eval_loaders(
train_loader, val_loader, aux_loaders, progress_bar
)
conv_loss = val_loss if val_loss else train_loss
self._conv_loss = sum(conv_loss) / len(conv_loss)
optimizers = self.estimator.optimizers(lr=lr, eps=eps) optimizers = self.estimator.optimizers(lr=lr, eps=eps)
while self._epoch < max_epochs and not self._converged( while self._epoch < max_epochs and not self._converged(stop_after_epochs):
self._epoch, stop_after_epochs
):
self._epoch += 1 self._epoch += 1
train_frac = f"{self._epoch}/{max_epochs}" #train_frac = f"{self._epoch}/{max_epochs}"
stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}" #stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}"
#print(f"Training epoch {train_frac}...") #print(f"Training epoch {train_frac}...")
#print(f"Stagnant epochs {stag_frac}...") #print(f"Stagnant epochs {stag_frac}...")
@@ -495,12 +512,43 @@ class Trainer[I, Kw: EstimatorKwargs]:
return self.estimator return self.estimator
def _converged(self, epoch: int, stop_after_epochs: int) -> bool: def _converged(self, stop_after_epochs: int) -> bool:
"""
Check if model has converged.
This method looks at the current "convergence loss" (validation-based
if a val set is provided to ``train()``, otherwise the training loss is
used), checking if it's the best yet recorded, incrementing the
stagnancy count if not. Convergence is asserted only if the number of
stagnant epochs exceeds ``stop_after_epochs``.
.. admonition:: Evaluation order
Convergence losses are recorded before the first training update,
so initial model states are appropriately benchmarked by the time
``_converged()`` is invoked.
If resuming training on the same dataset, one might expect only to
reset the stagnant epoch counter: you'll resume from the last
epoch, estimator state, and best seen loss, while allowed
``stop_after_epochs`` more chances for better validation.
If picking up training on a new dataset, even a training+validation
setting, resetting the best seen loss and best model state is
needed: you can't reliably compare the existing stats under new
data. It's somewhat ambiguous whether ``epoch`` absolutely must be
reset; you could continue logging metrics under the same named
session. But best practices would suggest restarting the epoch
count and have events logged under a new session heading when data
change.
"""
converged = False converged = False
if epoch == 0 or self._conv_loss < self._best_val_loss: if self._conv_loss < self._best_conv_loss:
self._best_val_loss = self._conv_loss
self._stagnant_epochs = 0 self._stagnant_epochs = 0
self._best_conv_loss = self._conv_loss
self._best_conv_epoch = self._epoch
self._best_model_state_dict = deepcopy(self.estimator.state_dict()) self._best_model_state_dict = deepcopy(self.estimator.state_dict())
else: else:
self._stagnant_epochs += 1 self._stagnant_epochs += 1

9
trainlib/utils/op.py Normal file
View File

@@ -0,0 +1,9 @@
from torch import Tensor
def r2_score(y: Tensor, y_hat: Tensor) -> Tensor:
    """
    Compute the coefficient of determination (R^2) as a 0-dim tensor.

    NOTE(review): despite the parameter names, the call sites in this
    commit invoke this as ``r2_score(predictions, labels)`` — i.e. ``y``
    is the model output and ``y_hat`` is the ground truth. The total sum
    of squares must therefore be taken over ``y_hat`` (the targets), not
    over ``y`` (the predictions); computing it over predictions gives a
    meaningless R^2. Confirm against callers if the convention changes.

    Parameters:
        y: predicted values
        y_hat: ground-truth target values

    Returns:
        ``1 - SS_res / SS_tot``. If the targets are constant,
        ``SS_tot`` is zero and the result is ``-inf``/``nan`` (no guard
        is applied, matching the original best-effort behavior).
    """
    # residual sum of squares: symmetric in (y, y_hat), so argument
    # order does not matter here
    ss_res = ((y - y_hat)**2).sum()
    # total sum of squares: variance of the *targets* around their mean
    ss_tot = ((y_hat - y_hat.mean())**2).sum()
    return 1 - ss_res / ss_tot

2
uv.lock generated
View File

@@ -1659,7 +1659,7 @@ wheels = [
[[package]] [[package]]
name = "trainlib" name = "trainlib"
version = "0.2.1" version = "0.3.1"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "colorama" }, { name = "colorama" },