From 85d176862e43c95fd846cd6009cd025f4eea98a5 Mon Sep 17 00:00:00 2001 From: smgr Date: Mon, 16 Mar 2026 03:20:21 -0700 Subject: [PATCH] consolidate Trainer object, synchronize event logging --- trainlib/estimator.py | 2 +- trainlib/estimators/mlp.py | 12 +-- trainlib/estimators/rnn.py | 36 +++---- trainlib/trainer.py | 197 ++++++++++++++++++------------------- 4 files changed, 123 insertions(+), 124 deletions(-) diff --git a/trainlib/estimator.py b/trainlib/estimator.py index 1eae175..c9cf9d7 100644 --- a/trainlib/estimator.py +++ b/trainlib/estimator.py @@ -164,7 +164,7 @@ class Estimator[Kw: EstimatorKwargs](nn.Module): self, writer: SummaryWriter, step: int | None = None, - val: bool = False, + group: str | None = None, **kwargs: Unpack[Kw], ) -> None: """ diff --git a/trainlib/estimators/mlp.py b/trainlib/estimators/mlp.py index 221a191..d041dde 100644 --- a/trainlib/estimators/mlp.py +++ b/trainlib/estimators/mlp.py @@ -21,7 +21,7 @@ class MLPKwargs(EstimatorKwargs): labels: NotRequired[Tensor] -class MLP[K: MLPKwargs](Estimator[K]): +class MLP[Kw: MLPKwargs](Estimator[Kw]): """ Base MLP architecture. 
""" @@ -82,19 +82,19 @@ class MLP[K: MLPKwargs](Estimator[K]): max=1.0, ) - def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]: + def forward(self, **kwargs: Unpack[Kw]) -> tuple[Tensor, ...]: inputs = kwargs["inputs"] x = self._net(inputs) return (x,) - def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]: + def loss(self, **kwargs: Unpack[Kw]) -> Generator[Tensor]: predictions = self(**kwargs)[0] labels = kwargs["labels"] yield F.mse_loss(predictions, labels) - def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]: + def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]: with torch.no_grad(): loss = next(self.loss(**kwargs)).item() @@ -135,8 +135,8 @@ class MLP[K: MLPKwargs](Estimator[K]): self, writer: SummaryWriter, step: int | None = None, - val: bool = False, - **kwargs: Unpack[K], + group: str | None = None, + **kwargs: Unpack[Kw], ) -> None: return None diff --git a/trainlib/estimators/rnn.py b/trainlib/estimators/rnn.py index 64bfcae..aaf7ae9 100644 --- a/trainlib/estimators/rnn.py +++ b/trainlib/estimators/rnn.py @@ -21,7 +21,7 @@ class RNNKwargs(EstimatorKwargs): labels: NotRequired[Tensor] -class LSTM[K: RNNKwargs](Estimator[K]): +class LSTM[Kw: RNNKwargs](Estimator[Kw]): """ Base RNN architecture. 
""" @@ -85,7 +85,7 @@ class LSTM[K: RNNKwargs](Estimator[K]): max=1.0, ) - def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]: + def forward(self, **kwargs: Unpack[Kw]) -> tuple[Tensor, ...]: inputs = kwargs["inputs"] # data shaped (B, C, T); map to (B, T, C) @@ -97,13 +97,13 @@ class LSTM[K: RNNKwargs](Estimator[K]): return z[:, -1, :], hidden - def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]: + def loss(self, **kwargs: Unpack[Kw]) -> Generator[Tensor]: predictions = self(**kwargs)[0] labels = kwargs["labels"] yield F.mse_loss(predictions, labels) - def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]: + def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]: with torch.no_grad(): loss = next(self.loss(**kwargs)).item() @@ -145,8 +145,8 @@ class LSTM[K: RNNKwargs](Estimator[K]): self, writer: SummaryWriter, step: int | None = None, - val: bool = False, - **kwargs: Unpack[K], + group: str | None = None, + **kwargs: Unpack[Kw], ) -> None: return None @@ -165,7 +165,7 @@ class MultiheadLSTMKwargs(EstimatorKwargs): auxiliary: NotRequired[Tensor] -class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]): +class MultiheadLSTM[Kw: MultiheadLSTMKwargs](Estimator[Kw]): def __init__( self, input_dim: int, @@ -223,7 +223,7 @@ class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]): max=1.0, ) - def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]: + def forward(self, **kwargs: Unpack[Kw]) -> tuple[Tensor, ...]: inputs = kwargs["inputs"] # data shaped (B, C, T); map to (B, T, C) @@ -237,7 +237,7 @@ class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]): return z[:, -1, :], zs[:, -1, :] - def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]: + def loss(self, **kwargs: Unpack[Kw]) -> Generator[Tensor]: pred, pred_aux = self(**kwargs) labels = kwargs["labels"] aux_labels = kwargs.get("auxiliary") @@ -247,7 +247,7 @@ class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]): else: yield F.mse_loss(pred, labels) + 
F.mse_loss(pred_aux, aux_labels) - def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]: + def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]: return { "grad_norm": get_grad_norm(self) } @@ -279,8 +279,8 @@ class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]): self, writer: SummaryWriter, step: int | None = None, - val: bool = False, - **kwargs: Unpack[K], + group: str | None = None, + **kwargs: Unpack[Kw], ) -> None: return None @@ -293,7 +293,7 @@ class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]): logger.info(f"| > {self.output_dim=}") -class ConvGRU[K: RNNKwargs](Estimator[K]): +class ConvGRU[Kw: RNNKwargs](Estimator[Kw]): """ Base recurrent convolutional architecture. @@ -404,7 +404,7 @@ class ConvGRU[K: RNNKwargs](Estimator[K]): max=1.0, ) - def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]: + def forward(self, **kwargs: Unpack[Kw]) -> tuple[Tensor, ...]: inputs = kwargs["inputs"] # embedding shaped (B, C, T) @@ -430,7 +430,7 @@ class ConvGRU[K: RNNKwargs](Estimator[K]): return (z,) - def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]: + def loss(self, **kwargs: Unpack[Kw]) -> Generator[Tensor]: predictions = self(**kwargs)[0] labels = kwargs["labels"] @@ -439,7 +439,7 @@ class ConvGRU[K: RNNKwargs](Estimator[K]): yield F.mse_loss(predictions, labels, reduction="mean") - def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]: + def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]: with torch.no_grad(): loss = next(self.loss(**kwargs)).item() @@ -480,8 +480,8 @@ class ConvGRU[K: RNNKwargs](Estimator[K]): self, writer: SummaryWriter, step: int | None = None, - val: bool = False, - **kwargs: Unpack[K], + group: str | None = None, + **kwargs: Unpack[Kw], ) -> None: return None diff --git a/trainlib/trainer.py b/trainlib/trainer.py index 32e91ea..16e6006 100644 --- a/trainlib/trainer.py +++ b/trainlib/trainer.py @@ -105,7 +105,8 @@ class Trainer[I, K: EstimatorKwargs]: """ self._epoch: int = 1 - 
self._event_log = defaultdict(lambda: defaultdict(dict)) + self._summary = defaultdict(lambda: defaultdict(dict)) + self._event_log = defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) self._val_loss = float("inf") self._best_val_loss = float("inf") @@ -117,7 +118,6 @@ class Trainer[I, K: EstimatorKwargs]: train_loader: DataLoader, batch_estimator_map: Callable[[I, Self], K], optimizers: tuple[Optimizer, ...], - writer: SummaryWriter, max_grad_norm: float | None = None, ) -> list[float]: """ @@ -154,15 +154,8 @@ class Trainer[I, K: EstimatorKwargs]: with tqdm(train_loader, unit="batch") as batches: for i, batch_data in enumerate(batches): est_kwargs = batch_estimator_map(batch_data, self) - - # one-time logging - if self._step == 0: - writer.add_graph( - ModelWrapper(self.estimator), - est_kwargs - ) - losses = self.estimator.loss(**est_kwargs) + for o_idx, (loss, optimizer) in enumerate( zip(losses, optimizers, strict=True) ): @@ -186,9 +179,6 @@ class Trainer[I, K: EstimatorKwargs]: loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1)) batches.set_postfix(loss=f"{loss_avg:8.2f}") - # grab the assured `inputs` key to measure num samples - self._step += len(est_kwargs["inputs"]) - # step estimator hyperparam schedules self.estimator.epoch_step() @@ -199,7 +189,6 @@ class Trainer[I, K: EstimatorKwargs]: loader: DataLoader, batch_estimator_map: Callable[[I, Self], K], loader_label: str, - writer: SummaryWriter, ) -> list[float]: """ Perform and record validation scores for a single epoch. 
@@ -230,8 +219,22 @@ class Trainer[I, K: EstimatorKwargs]: with tqdm(loader, unit="batch") as batches: for i, batch_data in enumerate(batches): est_kwargs = batch_estimator_map(batch_data, self) - losses = self.estimator.loss(**est_kwargs) + + # one-time logging + if self._epoch == 0: + self._writer.add_graph( + ModelWrapper(self.estimator), est_kwargs + ) + # once-per-epoch logging + if i == 0: + self.estimator.epoch_write( + self._writer, + step=self._epoch, + group=loader_label, + **est_kwargs + ) + loss_items = [] for o_idx, loss in enumerate(losses): if len(loss_sums) <= o_idx: @@ -247,17 +250,12 @@ class Trainer[I, K: EstimatorKwargs]: # log individual loss terms after each batch for o_idx, loss_item in enumerate(loss_items): - self._event_log[run_prefix][loader][name][step] = value - self._add_summary_item( - loader_label, f"loss_{o_idx}", loss_item - ) + self._log_event(loader_label, f"loss_{o_idx}", loss_item) # log metrics for batch estimator_metrics = self.estimator.metrics(**est_kwargs) for metric_name, metric_value in estimator_metrics.items(): - self._add_summary_item( - f"{loader_label}_{metric_name}", metric_value - ) + self._log_event(loader_label, metric_name, metric_value) return loss_sums @@ -265,8 +263,7 @@ class Trainer[I, K: EstimatorKwargs]: self, loaders: list[DataLoader], batch_estimator_map: Callable[[I, Self], K], - labels: list[str], - writer: SummaryWriter, + loader_labels: list[str], ) -> dict[str, list[float]]: """ Evaluate estimator over each provided dataloader. 
@@ -284,8 +281,8 @@ """ return { - label: self._eval_epoch(loader, label, batch_estimator_map, writer) - for loader, label in zip(loaders, labels, strict=True) + label: self._eval_epoch(loader, batch_estimator_map, label) + for loader, label in zip(loaders, loader_labels, strict=True) } def train[B]( @@ -307,7 +304,7 @@ summarize_every: int = 1, chkpt_every: int = 1, resume_latest: bool = False, - run_prefix: str | None = None, + session_name: str | None = None, summary_writer: SummaryWriter | None = None, aux_loaders: list[DataLoader] | None = None, aux_loader_labels: list[str] | None = None, @@ -362,10 +359,42 @@ customized) doesn't consistently yield a known type shape, however, so it's not appropriate to use ``I`` as the callable param type. - .. todo:: + .. admonition:: On session management - - Align eval stage of both train and val; currently train is from - before updates, val is from after at a given epoch + This method works around an implicit notion of training sessions. + Estimators are set during instantiation and effectively coupled + with ``Trainer`` instances, but datasets can be supplied + dynamically here. One can, for instance, run under one condition + (specific dataset, number of epochs, etc), then resume later under + another. Relevant details persist across calls: the estimator is + still attached, best val scores stowed, current epoch tracked. By + default, new ``session_names`` are always generated, but you can + write to the same TB location if you use the same + ``session_name`` across calls; that's about as close to a direct + training resume as you could want. + + If restarting training on new datasets, including short + fine-tuning on training-plus-validation data, it's often sensible + to call ``.reset()`` between ``.train()`` calls. 
While the same + estimator will be used, tracked variables will be wiped; subsequent + model updates take place under a fresh epoch, no val losses, and are + logged under a separate TB session. This is the general approach to + "piecemeal" training, i.e., incremental model updates under varying + conditions (often just data changes). + + .. warning:: + + Validation convergence when there are multiple losses may be + ambiguous. These are cases where certain parameter sets are + optimized independently; the sum over these losses may not reflect + expected or consistent behavior. For instance, we may record a low + cumulative loss early with a small 1st loss and moderate 2nd loss, + while later encountering a moderate 1st loss and small 2nd loss. We + might prefer the latter case, while ``_converged()`` will stick to + the former -- we need to consider possible weighting across losses, + or storing possibly several best models (e.g., for each loss, the + model that scores best, plus the one scoring best cumulatively, + etc). 
Parameters: dataset: dataset to train the estimator @@ -402,16 +431,14 @@ class Trainer[I, K: EstimatorKwargs]: logger.info(f"| > with device: {self.device}") logger.info(f"| > core count: {os.cpu_count()}") - writer: SummaryWriter - run_prefix = run_prefix or str(int(time.time())) - if summary_writer is None: - writer = SummaryWriter(f"{Path(self.tblog_dir, run_prefix)}") - else: - writer = summary_writer + self._session_name = session_name or str(int(time.time())) + tblog_path = Path(self.tblog_dir, self._session_name) + self._writer = summary_writer or SummaryWriter(f"{tblog_path}") aux_loaders = aux_loaders or [] aux_loader_labels = aux_loader_labels or [] + optimizers = self.estimator.optimizers(lr=lr, eps=eps) train_loader, val_loader = self.get_dataloaders( dataset, batch_size, @@ -426,12 +453,8 @@ class Trainer[I, K: EstimatorKwargs]: loader_labels = ["train", "val", *aux_loader_labels] # evaluate model on dataloaders once before training starts - self._eval_loaders(loaders, batch_estimator_map, # loader_labels, writer) + self._eval_loaders(loaders, batch_estimator_map, loader_labels) - optimizers = self.estimator.optimizers(lr=lr, eps=eps) - - self._step = 0 - self._epoch = 1 # start from 1 for logging convenience while self._epoch <= max_epochs and not self._converged( self._epoch, stop_after_epochs ): @@ -446,41 +469,22 @@ class Trainer[I, K: EstimatorKwargs]: batch_estimator_map, optimizers, max_grad_norm, - writer, ) - self._add_summary_item( - "epoch_time_sec", time.time() - epoch_start_time - ) - # writer steps are measured in number of samples; log epoch as an - # alternative step resolution for interpreting event timing - self._add_summary_item("step", float(self._step)) - - - # once-per-epoch logging - self.estimator.epoch_write( - writer, - step=self._step, - val=True, - **est_kwargs - ) - - - - + epoch_end_time = time.time() - epoch_start_time + self._log_event("train", "epoch_duration", epoch_end_time) loss_sum_map = self._eval_loaders( - 
loaders, loader_labels, batch_estimator_map, writer + loaders, + batch_estimator_map, + loader_labels, ) - # convergence of multiple losses may be ambiguous val_loss_sums = loss_sum_map["val"] self._val_loss = sum(val_loss_sums) / len(val_loader) if self._epoch % summarize_every == 0: - self._summarize(writer, self._epoch, run_prefix) - + self._summarize() if self._epoch % chkpt_every == 0: - self.save_model(self.chkpt_dir, self._epoch, run_prefix) - + self.save_model() self._epoch += 1 return self.estimator @@ -514,6 +518,10 @@ class Trainer[I, K: EstimatorKwargs]: ) -> tuple[DataLoader, DataLoader]: """ Create training and validation dataloaders for the provided dataset. + + .. todo:: + + Decide on policy for empty val dataloaders """ if dataset_split_kwargs is None: @@ -571,12 +579,7 @@ class Trainer[I, K: EstimatorKwargs]: return train_loader, val_loader - def _summarize( - self, - writer: SummaryWriter, - epoch: int, - run_prefix: str - ) -> None: + def _summarize(self) -> None: """ Flush the training summary to the TensorBoard summary writer. @@ -588,21 +591,24 @@ class Trainer[I, K: EstimatorKwargs]: averages from epochs 1-10. 
""" - summary_values = defaultdict(list) - for name, records in self._summary.items(): - for value, step in records: - writer.add_scalar(name, value, step) - # writer.add_scalar(name, value, self._epoch) - summary_values[name].append(value) - self._event_log[run_prefix][name][step] = value + epoch_values = defaultdict(lambda: defaultdict(list)) + for group, records in self._summary.items(): + for name, steps in records.items(): + for step, value in steps.items(): + self._writer.add_scalar(f"{group}-{name}", value, step) + if step == self._epoch: + epoch_values[group][name].append(value) - print(f"==== Epoch [{epoch}] summary ====") - for name, values in summary_values.items(): - mean_value = torch.tensor(values).mean().item() - print(f"> ({len(values)}) {name} :: {mean_value:.2f}") + print(f"==== Epoch [{self._epoch}] summary ====") + for group, records in epoch_values.items(): + for name, values in records.items(): + mean_value = torch.tensor(values).mean().item() + print( + f"> ({len(values)}) [{group}] {name} :: {mean_value:.2f}" + ) - writer.flush() - self._summary = defaultdict(list) + self._writer.flush() + self._summary = defaultdict(lambda: defaultdict(dict)) def _get_optimizer_parameters( self, @@ -615,11 +621,9 @@ class Trainer[I, K: EstimatorKwargs]: if param.grad is not None ] - def _add_summary_item(self, group: str, name: str, value: float) -> None: - # self._summary[name].append((value, self._step)) - self._summary[name].append((value, self._epoch)) - - self._event_log[run_prefix][name][step] = value + def _log_event(self, group: str, name: str, value: float) -> None: + self._summary[group][name][self._epoch] = value + self._event_log[self._session_name][group][name][self._epoch] = value def get_batch_outputs[B]( self, @@ -646,12 +650,7 @@ class Trainer[I, K: EstimatorKwargs]: return metrics - def save_model( - self, - chkpt_dir: str | Path, - epoch: int, - run_prefix: str, - ) -> None: + def save_model(self) -> None: """ Save a model checkpoint. 
""" @@ -661,9 +660,9 @@ class Trainer[I, K: EstimatorKwargs]: model_buff.seek(0) model_class = self.estimator.__class__.__name__ - chkpt_name = f"m_{model_class}-e_{epoch}.pth" + chkpt_name = f"m_{model_class}-e_{self._epoch}.pth" - chkpt_dir = Path(chkpt_dir, run_prefix) + chkpt_dir = Path(self.chkpt_dir, self._session_name) chkpt_path = Path(chkpt_dir, chkpt_name) chkpt_dir.mkdir(parents=True, exist_ok=True)