Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| fdccb4c5eb |
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "trainlib"
|
name = "trainlib"
|
||||||
version = "0.3.0"
|
version = "0.3.1"
|
||||||
description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training."
|
description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training."
|
||||||
requires-python = ">=3.13"
|
requires-python = ">=3.13"
|
||||||
authors = [
|
authors = [
|
||||||
|
|||||||
@@ -108,12 +108,26 @@ class SlidingWindowDataset(TupleDataset[Tensor]):
|
|||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
lookahead: int = 1,
|
lookahead: int = 1,
|
||||||
num_windows: int = 1,
|
num_windows: int = 1,
|
||||||
|
pad_mode: str = "constant",
|
||||||
|
#fill_with: str = "zero",
|
||||||
**kwargs: Unpack[DatasetKwargs],
|
**kwargs: Unpack[DatasetKwargs],
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""
|
||||||
|
Parameters:
|
||||||
|
TODO: implement options for `fill_with`; currently just passing
|
||||||
|
through a `pad_mode` to the Functional call, which does the job
|
||||||
|
|
||||||
|
fill_with: strategy to use for padding values in windows
|
||||||
|
- `zero`: fill with zeros
|
||||||
|
- `left`: use nearest window column (repeat leftmost)
|
||||||
|
- `mean`: fill with the window mean
|
||||||
|
"""
|
||||||
|
|
||||||
self.lookback = lookback
|
self.lookback = lookback
|
||||||
self.offset = offset
|
self.offset = offset
|
||||||
self.lookahead = lookahead
|
self.lookahead = lookahead
|
||||||
self.num_windows = num_windows
|
self.num_windows = num_windows
|
||||||
|
self.pad_mode = pad_mode
|
||||||
|
|
||||||
super().__init__(domain, **kwargs)
|
super().__init__(domain, **kwargs)
|
||||||
|
|
||||||
@@ -123,7 +137,7 @@ class SlidingWindowDataset(TupleDataset[Tensor]):
|
|||||||
batch_index: int,
|
batch_index: int,
|
||||||
) -> list[tuple[Tensor, ...]]:
|
) -> list[tuple[Tensor, ...]]:
|
||||||
"""
|
"""
|
||||||
Backward pads first sequence over (lookback-1) length, and steps the
|
Backward pads window sequences over (lookback-1) length, and steps the
|
||||||
remaining items forward by the lookahead.
|
remaining items forward by the lookahead.
|
||||||
|
|
||||||
Batch data:
|
Batch data:
|
||||||
@@ -166,7 +180,7 @@ class SlidingWindowDataset(TupleDataset[Tensor]):
|
|||||||
exceeds the offset.
|
exceeds the offset.
|
||||||
|
|
||||||
To get windows starting with the first index at the left: we first set
|
To get windows starting with the first index at the left: we first set
|
||||||
out window size (call it L), determined by `lookback`. Then the
|
our window size (call it L), determined by `lookback`. Then the
|
||||||
rightmost index we want will be `L-1`, which determines our `offset`
|
rightmost index we want will be `L-1`, which determines our `offset`
|
||||||
setting.
|
setting.
|
||||||
|
|
||||||
@@ -193,7 +207,7 @@ class SlidingWindowDataset(TupleDataset[Tensor]):
|
|||||||
# for window sized `lb`, we pad with `lb-1` zeros. We then take off
|
# for window sized `lb`, we pad with `lb-1` zeros. We then take off
|
||||||
# the amount of our offset, which in the extreme cases does no
|
# the amount of our offset, which in the extreme cases does no
|
||||||
# padding.
|
# padding.
|
||||||
xip = F.pad(t, ((lb-1) - off, 0))
|
xip = F.pad(t, ((lb-1) - off, 0), mode=self.pad_mode)
|
||||||
|
|
||||||
# extract sliding windows over the padded tensor
|
# extract sliding windows over the padded tensor
|
||||||
# unfold(-1, lb, 1) slides over the last dim, 1 step at a time, for
|
# unfold(-1, lb, 1) slides over the last dim, 1 step at a time, for
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from torch import nn, Tensor
|
|||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
from trainlib.utils import op
|
||||||
from trainlib.estimator import Estimator, EstimatorKwargs
|
from trainlib.estimator import Estimator, EstimatorKwargs
|
||||||
from trainlib.utils.type import OptimizerKwargs
|
from trainlib.utils.type import OptimizerKwargs
|
||||||
from trainlib.utils.module import get_grad_norm
|
from trainlib.utils.module import get_grad_norm
|
||||||
@@ -102,6 +103,7 @@ class LSTM[Kw: RNNKwargs](Estimator[Kw]):
|
|||||||
labels = kwargs["labels"]
|
labels = kwargs["labels"]
|
||||||
|
|
||||||
yield F.mse_loss(predictions, labels)
|
yield F.mse_loss(predictions, labels)
|
||||||
|
#yield F.l1_loss(predictions, labels)
|
||||||
|
|
||||||
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
|
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -109,12 +111,16 @@ class LSTM[Kw: RNNKwargs](Estimator[Kw]):
|
|||||||
|
|
||||||
predictions = self(**kwargs)[0]
|
predictions = self(**kwargs)[0]
|
||||||
labels = kwargs["labels"]
|
labels = kwargs["labels"]
|
||||||
|
|
||||||
|
mse = F.mse_loss(predictions, labels).item()
|
||||||
mae = F.l1_loss(predictions, labels).item()
|
mae = F.l1_loss(predictions, labels).item()
|
||||||
|
r2 = op.r2_score(predictions, labels).item()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
# "loss": loss,
|
"loss": loss,
|
||||||
"mse": loss,
|
"mse": mse,
|
||||||
"mae": mae,
|
"mae": mae,
|
||||||
|
"r2": r2,
|
||||||
"grad_norm": get_grad_norm(self)
|
"grad_norm": get_grad_norm(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -377,7 +383,8 @@ class ConvGRU[Kw: RNNKwargs](Estimator[Kw]):
|
|||||||
# will be (B, T, C), applies indep at each time step across channels
|
# will be (B, T, C), applies indep at each time step across channels
|
||||||
# self.dense_z = nn.Linear(layer_in_dim, self.output_dim)
|
# self.dense_z = nn.Linear(layer_in_dim, self.output_dim)
|
||||||
|
|
||||||
# will be (B, C, T), applies indep at each time step across channels
|
# will be (B, Co, 1), applies indep at each channel across temporal dim
|
||||||
|
# size time steps
|
||||||
self.dense_z = TDNNLayer(
|
self.dense_z = TDNNLayer(
|
||||||
layer_in_dim,
|
layer_in_dim,
|
||||||
self.output_dim,
|
self.output_dim,
|
||||||
@@ -438,6 +445,7 @@ class ConvGRU[Kw: RNNKwargs](Estimator[Kw]):
|
|||||||
predictions = predictions.squeeze(-1)
|
predictions = predictions.squeeze(-1)
|
||||||
|
|
||||||
yield F.mse_loss(predictions, labels, reduction="mean")
|
yield F.mse_loss(predictions, labels, reduction="mean")
|
||||||
|
#yield F.l1_loss(predictions, labels)
|
||||||
|
|
||||||
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
|
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -445,11 +453,16 @@ class ConvGRU[Kw: RNNKwargs](Estimator[Kw]):
|
|||||||
|
|
||||||
predictions = self(**kwargs)[0].squeeze(-1)
|
predictions = self(**kwargs)[0].squeeze(-1)
|
||||||
labels = kwargs["labels"]
|
labels = kwargs["labels"]
|
||||||
|
|
||||||
|
mse = F.mse_loss(predictions, labels).item()
|
||||||
mae = F.l1_loss(predictions, labels).item()
|
mae = F.l1_loss(predictions, labels).item()
|
||||||
|
r2 = op.r2_score(predictions, labels).item()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"mse": loss,
|
"loss": loss,
|
||||||
|
"mse": mse,
|
||||||
"mae": mae,
|
"mae": mae,
|
||||||
|
"r2": r2,
|
||||||
"grad_norm": get_grad_norm(self)
|
"grad_norm": get_grad_norm(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -64,14 +64,17 @@ class Plotter[Kw: EstimatorKwargs]:
|
|||||||
for i, loader in enumerate(self.dataloaders):
|
for i, loader in enumerate(self.dataloaders):
|
||||||
label = self.dataloader_labels[i]
|
label = self.dataloader_labels[i]
|
||||||
|
|
||||||
actual = torch.cat([
|
actual = [
|
||||||
self.kw_to_actual(batch_kwargs).detach().cpu()
|
self.kw_to_actual(batch_kwargs).detach().cpu()
|
||||||
for batch_kwargs in loader
|
for batch_kwargs in loader
|
||||||
])
|
]
|
||||||
output = torch.cat([
|
actual = torch.cat([ai.reshape(*([*ai.shape]+[1])[:2]) for ai in actual])
|
||||||
|
|
||||||
|
output = [
|
||||||
self.trainer.estimator(**batch_kwargs)[0].detach().cpu()
|
self.trainer.estimator(**batch_kwargs)[0].detach().cpu()
|
||||||
for batch_kwargs in loader
|
for batch_kwargs in loader
|
||||||
])
|
]
|
||||||
|
output = torch.cat([oi.reshape(*([*oi.shape]+[1])[:2]) for oi in output])
|
||||||
|
|
||||||
data_tuples.append((actual, output, label))
|
data_tuples.append((actual, output, label))
|
||||||
|
|
||||||
|
|||||||
@@ -104,17 +104,31 @@ class Trainer[I, Kw: EstimatorKwargs]:
|
|||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self, resume: bool = False) -> None:
|
||||||
"""
|
"""
|
||||||
Set initial tracking parameters for the primary training loop.
|
Set initial tracking parameters for the primary training loop.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
resume: if ``True``, just resets the stagnant epoch counter, with
|
||||||
|
the aim of continuing any existing training state under a
|
||||||
|
resumed ``train()`` call. This should likely only be set when
|
||||||
|
training is continued on the same dataset and the goal is to
|
||||||
|
resume convergence loss-based scoring for a fresh set of
|
||||||
|
epochs. If even that element of the training loop should resume
|
||||||
|
(which should only happen if a training loop was interrupted or
|
||||||
|
a max epoch limit was reached), then this method shouldn't be
|
||||||
|
called at all between ``train()`` invocations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
self._stagnant_epochs = 0
|
||||||
|
|
||||||
|
if not resume:
|
||||||
self._epoch: int = 0
|
self._epoch: int = 0
|
||||||
self._summary = defaultdict(lambda: defaultdict(list))
|
self._summary = defaultdict(lambda: defaultdict(list))
|
||||||
|
|
||||||
self._conv_loss = float("inf")
|
self._conv_loss = float("inf")
|
||||||
self._best_conv_loss = float("inf")
|
self._best_conv_loss = float("inf")
|
||||||
self._stagnant_epochs = 0
|
self._best_conv_epoch = 0
|
||||||
self._best_model_state_dict: dict[str, Any] = {}
|
self._best_model_state_dict: dict[str, Any] = {}
|
||||||
|
|
||||||
def _train_epoch(
|
def _train_epoch(
|
||||||
@@ -459,15 +473,18 @@ class Trainer[I, Kw: EstimatorKwargs]:
|
|||||||
progress_bar = tqdm(train_loader, unit="batch")
|
progress_bar = tqdm(train_loader, unit="batch")
|
||||||
|
|
||||||
# evaluate model on dataloaders once before training starts
|
# evaluate model on dataloaders once before training starts
|
||||||
self._eval_loaders(train_loader, val_loader, aux_loaders, progress_bar)
|
train_loss, val_loss, *_ = self._eval_loaders(
|
||||||
|
train_loader, val_loader, aux_loaders, progress_bar
|
||||||
|
)
|
||||||
|
conv_loss = val_loss if val_loss else train_loss
|
||||||
|
self._conv_loss = sum(conv_loss) / len(conv_loss)
|
||||||
|
|
||||||
optimizers = self.estimator.optimizers(lr=lr, eps=eps)
|
optimizers = self.estimator.optimizers(lr=lr, eps=eps)
|
||||||
|
|
||||||
while self._epoch < max_epochs and not self._converged(
|
while self._epoch < max_epochs and not self._converged(stop_after_epochs):
|
||||||
self._epoch, stop_after_epochs
|
|
||||||
):
|
|
||||||
self._epoch += 1
|
self._epoch += 1
|
||||||
train_frac = f"{self._epoch}/{max_epochs}"
|
#train_frac = f"{self._epoch}/{max_epochs}"
|
||||||
stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}"
|
#stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}"
|
||||||
#print(f"Training epoch {train_frac}...")
|
#print(f"Training epoch {train_frac}...")
|
||||||
#print(f"Stagnant epochs {stag_frac}...")
|
#print(f"Stagnant epochs {stag_frac}...")
|
||||||
|
|
||||||
@@ -495,12 +512,43 @@ class Trainer[I, Kw: EstimatorKwargs]:
|
|||||||
|
|
||||||
return self.estimator
|
return self.estimator
|
||||||
|
|
||||||
def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
|
def _converged(self, stop_after_epochs: int) -> bool:
|
||||||
|
"""
|
||||||
|
Check if model has converged.
|
||||||
|
|
||||||
|
This method looks at the current "convergence loss" (validation-based
|
||||||
|
if a val set is provided to ``train()``, otherwise the training loss is
|
||||||
|
used), checking if it's the best yet recorded, incrementing the
|
||||||
|
stagnancy count if not. Convergence is asserted only if the number of
|
||||||
|
stagnant epochs exceeds ``stop_after_epochs``.
|
||||||
|
|
||||||
|
.. admonition:: Evaluation order
|
||||||
|
|
||||||
|
Convergence losses are recorded before the first training update,
|
||||||
|
so initial model states are appropriately benchmarked by the time
|
||||||
|
``_converged()`` is invoked.
|
||||||
|
|
||||||
|
If resuming training on the same dataset, one might expect only to
|
||||||
|
reset the stagnant epoch counter: you'll resume from the last
|
||||||
|
epoch, estimator state, and best seen loss, while being allowed
|
||||||
|
``stop_after_epochs`` more chances for better validation.
|
||||||
|
|
||||||
|
If picking up training on a new dataset, even a training+validation
|
||||||
|
setting, resetting the best seen loss and best model state is
|
||||||
|
needed: you can't reliably compare the existing stats under new
|
||||||
|
data. It's somewhat ambiguous whether ``epoch`` absolutely must be
|
||||||
|
reset; you could continue logging metrics under the same named
|
||||||
|
session. But best practices would suggest restarting the epoch
|
||||||
|
count and having events logged under a new session heading when data
|
||||||
|
change.
|
||||||
|
"""
|
||||||
|
|
||||||
converged = False
|
converged = False
|
||||||
|
|
||||||
if epoch == 0 or self._conv_loss < self._best_val_loss:
|
if self._conv_loss < self._best_conv_loss:
|
||||||
self._best_val_loss = self._conv_loss
|
|
||||||
self._stagnant_epochs = 0
|
self._stagnant_epochs = 0
|
||||||
|
self._best_conv_loss = self._conv_loss
|
||||||
|
self._best_conv_epoch = self._epoch
|
||||||
self._best_model_state_dict = deepcopy(self.estimator.state_dict())
|
self._best_model_state_dict = deepcopy(self.estimator.state_dict())
|
||||||
else:
|
else:
|
||||||
self._stagnant_epochs += 1
|
self._stagnant_epochs += 1
|
||||||
|
|||||||
9
trainlib/utils/op.py
Normal file
9
trainlib/utils/op.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
def r2_score(y: Tensor, y_hat: Tensor) -> Tensor:
    """
    Compute the coefficient of determination (R^2) over all elements.

    The score is ``1 - ss_res / ss_tot``, where ``ss_res`` is the summed
    squared difference between ``y`` and ``y_hat`` and ``ss_tot`` is the
    summed squared deviation of ``y`` from its own mean. The score is not
    symmetric in its arguments: only ``y`` contributes the baseline mean, so
    argument order matters.

    Parameters:
        y: reference values; their mean defines the baseline
        y_hat: values scored against ``y``

    Returns:
        Scalar tensor holding the R^2 score. If ``y`` has zero variance the
        ratio is undefined; in that case the result is ``1`` when ``y_hat``
        matches ``y`` exactly and ``0`` otherwise, rather than NaN or -inf.
    """

    ss_res = ((y - y_hat) ** 2).sum()
    ss_tot = ((y - y.mean()) ** 2).sum()

    # Guard the division: a constant `y` gives ss_tot == 0, which would
    # otherwise produce NaN (0/0) or -inf (x/0) and pollute logged metrics.
    zero_var = ss_tot == 0
    safe_tot = ss_tot.masked_fill(zero_var, 1.0)
    r2 = 1 - ss_res / safe_tot

    # Degenerate-case value: perfect match scores 1, anything else scores 0.
    fallback = (ss_res == 0).to(r2.dtype)

    return r2.where(~zero_var, fallback)
|
||||||
Reference in New Issue
Block a user