diff --git a/pyproject.toml b/pyproject.toml index ebd9f53..c5bdeca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "trainlib" -version = "0.3.0" +version = "0.3.1" description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training." requires-python = ">=3.13" authors = [ diff --git a/trainlib/datasets/memory.py b/trainlib/datasets/memory.py index 3eb178b..9e2cc7e 100644 --- a/trainlib/datasets/memory.py +++ b/trainlib/datasets/memory.py @@ -108,12 +108,26 @@ class SlidingWindowDataset(TupleDataset[Tensor]): offset: int = 0, lookahead: int = 1, num_windows: int = 1, + pad_mode: str = "constant", + #fill_with: str = "zero", **kwargs: Unpack[DatasetKwargs], ) -> None: + """ + Parameters: + TODO: implement options for `fill_with`; currently just passing + through a `pad_mode` the Functional call, which does the job + + fill_with: strategy to use for padding values in windows + - `zero`: fill with zeros + - `left`: use nearest window column (repeat leftmost) + - `mean`: fill with the window mean + """ + self.lookback = lookback self.offset = offset self.lookahead = lookahead self.num_windows = num_windows + self.pad_mode = pad_mode super().__init__(domain, **kwargs) @@ -123,7 +137,7 @@ class SlidingWindowDataset(TupleDataset[Tensor]): batch_index: int, ) -> list[tuple[Tensor, ...]]: """ - Backward pads first sequence over (lookback-1) length, and steps the + Backward pads window sequences over (lookback-1) length, and steps the remaining items forward by the lookahead. Batch data: @@ -166,7 +180,7 @@ class SlidingWindowDataset(TupleDataset[Tensor]): exceeds the offset. To get windows starting with the first index at the left: we first set - out window size (call it L), determined by `lookback`. Then the + our window size (call it L), determined by `lookback`. Then the rightmost index we want will be `L-1`, which determines our `offset` setting. @@ -193,7 +207,7 @@ class SlidingWindowDataset(TupleDataset[Tensor]): # for window sized `lb`, we pad with `lb-1` zeros. We then take off # the amount of our offset, which in the extreme cases does no # padding. - xip = F.pad(t, ((lb-1) - off, 0)) + xip = F.pad(t, ((lb-1) - off, 0), mode=self.pad_mode) # extract sliding windows over the padded tensor # unfold(-1, lb, 1) slides over the last dim, 1 step at a time, for diff --git a/trainlib/estimators/rnn.py b/trainlib/estimators/rnn.py index aaf7ae9..d037ee6 100644 --- a/trainlib/estimators/rnn.py +++ b/trainlib/estimators/rnn.py @@ -8,6 +8,7 @@ from torch import nn, Tensor from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter +from trainlib.utils import op from trainlib.estimator import Estimator, EstimatorKwargs from trainlib.utils.type import OptimizerKwargs from trainlib.utils.module import get_grad_norm @@ -102,6 +103,7 @@ class LSTM[Kw: RNNKwargs](Estimator[Kw]): labels = kwargs["labels"] yield F.mse_loss(predictions, labels) + #yield F.l1_loss(predictions, labels) def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]: with torch.no_grad(): @@ -109,12 +111,16 @@ class LSTM[Kw: RNNKwargs](Estimator[Kw]): predictions = self(**kwargs)[0] labels = kwargs["labels"] + + mse = F.mse_loss(predictions, labels).item() mae = F.l1_loss(predictions, labels).item() + r2 = op.r2_score(predictions, labels).item() return { - # "loss": loss, - "mse": loss, + "loss": loss, + "mse": mse, "mae": mae, + "r2": r2, "grad_norm": get_grad_norm(self) } @@ -377,7 +383,8 @@ class ConvGRU[Kw: RNNKwargs](Estimator[Kw]): # will be (B, T, C), applies indep at each time step across channels # self.dense_z = nn.Linear(layer_in_dim, self.output_dim) - # will be (B, C, T), applies indep at each time step across channels + # will be (B, Co, 1), applies indep at each channel across temporal dim + # size time steps self.dense_z = TDNNLayer( layer_in_dim, self.output_dim, @@ -438,6 +445,7 @@ class ConvGRU[Kw: RNNKwargs](Estimator[Kw]): predictions = predictions.squeeze(-1) yield F.mse_loss(predictions, labels, reduction="mean") + #yield F.l1_loss(predictions, labels) def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]: with torch.no_grad(): @@ -445,11 +453,16 @@ class ConvGRU[Kw: RNNKwargs](Estimator[Kw]): predictions = self(**kwargs)[0].squeeze(-1) labels = kwargs["labels"] + + mse = F.mse_loss(predictions, labels).item() mae = F.l1_loss(predictions, labels).item() + r2 = op.r2_score(predictions, labels).item() return { - "mse": loss, + "loss": loss, + "mse": mse, "mae": mae, + "r2": r2, "grad_norm": get_grad_norm(self) } diff --git a/trainlib/plotter.py b/trainlib/plotter.py index 4fd9053..02d5650 100644 --- a/trainlib/plotter.py +++ b/trainlib/plotter.py @@ -64,14 +64,17 @@ class Plotter[Kw: EstimatorKwargs]: for i, loader in enumerate(self.dataloaders): label = self.dataloader_labels[i] - actual = torch.cat([ + actual = [ self.kw_to_actual(batch_kwargs).detach().cpu() for batch_kwargs in loader - ]) - output = torch.cat([ + ] + actual = torch.cat([ai.reshape(*([*ai.shape]+[1])[:2]) for ai in actual]) + + output = [ self.trainer.estimator(**batch_kwargs)[0].detach().cpu() for batch_kwargs in loader - ]) + ] + output = torch.cat([oi.reshape(*([*oi.shape]+[1])[:2]) for oi in output]) data_tuples.append((actual, output, label)) diff --git a/trainlib/trainer.py b/trainlib/trainer.py index 93c54ef..670459a 100644 --- a/trainlib/trainer.py +++ b/trainlib/trainer.py @@ -104,18 +104,32 @@ class Trainer[I, Kw: EstimatorKwargs]: self.reset() - def reset(self) -> None: + def reset(self, resume: bool = False) -> None: """ Set initial tracking parameters for the primary training loop. + + Parameters: + resume: if ``True``, just resets the stagnant epoch counter, with + the aims of continuing any existing training state under + resumed ``train()`` call. This should likely only be set when + training is continued on the same dataset and the goal is to + resume convergence loss-based scoring for a fresh set of + epochs. If even that element of the training loop should resume + (which should only happen if a training loop was interrupted or + a max epoch limit was reached), then this method shouldn't be + called at all between ``train()`` invocations. """ - self._epoch: int = 0 - self._summary = defaultdict(lambda: defaultdict(list)) - - self._conv_loss = float("inf") - self._best_conv_loss = float("inf") self._stagnant_epochs = 0 - self._best_model_state_dict: dict[str, Any] = {} + + if not resume: + self._epoch: int = 0 + self._summary = defaultdict(lambda: defaultdict(list)) + + self._conv_loss = float("inf") + self._best_conv_loss = float("inf") + self._best_conv_epoch = 0 + self._best_model_state_dict: dict[str, Any] = {} def _train_epoch( self, @@ -459,15 +473,18 @@ class Trainer[I, Kw: EstimatorKwargs]: progress_bar = tqdm(train_loader, unit="batch") # evaluate model on dataloaders once before training starts - self._eval_loaders(train_loader, val_loader, aux_loaders, progress_bar) + train_loss, val_loss, *_ = self._eval_loaders( + train_loader, val_loader, aux_loaders, progress_bar + ) + conv_loss = val_loss if val_loss else train_loss + self._conv_loss = sum(conv_loss) / len(conv_loss) + optimizers = self.estimator.optimizers(lr=lr, eps=eps) - while self._epoch < max_epochs and not self._converged( - self._epoch, stop_after_epochs - ): + while self._epoch < max_epochs and not self._converged(stop_after_epochs): self._epoch += 1 - train_frac = f"{self._epoch}/{max_epochs}" - stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}" + #train_frac = f"{self._epoch}/{max_epochs}" + #stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}" #print(f"Training epoch {train_frac}...") #print(f"Stagnant epochs {stag_frac}...") @@ -495,12 +512,43 @@ class Trainer[I, Kw: EstimatorKwargs]: return self.estimator - def _converged(self, epoch: int, stop_after_epochs: int) -> bool: + def _converged(self, stop_after_epochs: int) -> bool: + """ + Check if model has converged. + + This method looks at the current "convergence loss" (validation-based + if a val set is provided to ``train()``, otherwise the training loss is + used), checking if it's the best yet recorded, incrementing the + stagnancy count if not. Convergence is asserted only if the number of + stagnant epochs exceeds ``stop_after_epochs``. + + .. admonition:: Evaluation order + + Convergence losses are recorded before the first training update, + so initial model states are appropriately benchmarked by the time + ``_converged()`` is invoked. + + If resuming training on the same dataset, one might expect only to + reset the stagnant epoch counter: you'll resume from the last + epoch, estimator state, and best seen loss, while allowed + ``stop_after_epochs`` more chances for better validation. + + If picking up training on a new dataset, even a training+validation + setting, resetting the best seen loss and best model state is + needed: you can't reliably compare the existing stats under new + data. It's somewhat ambiguous whether ``epoch`` absolutely must be + reset; you could continue logging metrics under the same named + session. But best practices would suggest restarting the epoch + count and have events logged under a new session heading when data + change. + """ + converged = False - if epoch == 0 or self._conv_loss < self._best_val_loss: - self._best_val_loss = self._conv_loss + if self._conv_loss < self._best_conv_loss: self._stagnant_epochs = 0 + self._best_conv_loss = self._conv_loss + self._best_conv_epoch = self._epoch self._best_model_state_dict = deepcopy(self.estimator.state_dict()) else: self._stagnant_epochs += 1 diff --git a/trainlib/utils/op.py b/trainlib/utils/op.py new file mode 100644 index 0000000..3a2a3b3 --- /dev/null +++ b/trainlib/utils/op.py @@ -0,0 +1,9 @@ +from torch import Tensor + + +def r2_score(y: Tensor, y_hat: Tensor) -> Tensor: + ss_res = ((y - y_hat)**2).sum() + ss_tot = ((y - y.mean())**2).sum() + r2 = 1 - ss_res / ss_tot + + return r2 diff --git a/uv.lock b/uv.lock index 54efd38..cce8dd1 100644 --- a/uv.lock +++ b/uv.lock @@ -1659,7 +1659,7 @@ wheels = [ [[package]] name = "trainlib" -version = "0.2.1" +version = "0.3.1" source = { editable = "." } dependencies = [ { name = "colorama" },