consolidate Trainer object, synchronize event logging

This commit is contained in:
2026-03-16 03:20:21 -07:00
parent 9a0b0e5626
commit 85d176862e
4 changed files with 123 additions and 124 deletions

View File

@@ -164,7 +164,7 @@ class Estimator[Kw: EstimatorKwargs](nn.Module):
self,
writer: SummaryWriter,
step: int | None = None,
val: bool = False,
group: str | None = None,
**kwargs: Unpack[Kw],
) -> None:
"""

View File

@@ -21,7 +21,7 @@ class MLPKwargs(EstimatorKwargs):
labels: NotRequired[Tensor]
class MLP[K: MLPKwargs](Estimator[K]):
class MLP[Kw: MLPKwargs](Estimator[Kw]):
"""
Base MLP architecture.
"""
@@ -82,19 +82,19 @@ class MLP[K: MLPKwargs](Estimator[K]):
max=1.0,
)
def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]:
def forward(self, **kwargs: Unpack[Kw]) -> tuple[Tensor, ...]:
inputs = kwargs["inputs"]
x = self._net(inputs)
return (x,)
def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]:
def loss(self, **kwargs: Unpack[Kw]) -> Generator[Tensor]:
predictions = self(**kwargs)[0]
labels = kwargs["labels"]
yield F.mse_loss(predictions, labels)
def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]:
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
with torch.no_grad():
loss = next(self.loss(**kwargs)).item()
@@ -135,8 +135,8 @@ class MLP[K: MLPKwargs](Estimator[K]):
self,
writer: SummaryWriter,
step: int | None = None,
val: bool = False,
**kwargs: Unpack[K],
group: str | None = None,
**kwargs: Unpack[Kw],
) -> None:
return None

View File

@@ -21,7 +21,7 @@ class RNNKwargs(EstimatorKwargs):
labels: NotRequired[Tensor]
class LSTM[K: RNNKwargs](Estimator[K]):
class LSTM[Kw: RNNKwargs](Estimator[Kw]):
"""
Base RNN architecture.
"""
@@ -85,7 +85,7 @@ class LSTM[K: RNNKwargs](Estimator[K]):
max=1.0,
)
def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]:
def forward(self, **kwargs: Unpack[Kw]) -> tuple[Tensor, ...]:
inputs = kwargs["inputs"]
# data shaped (B, C, T); map to (B, T, C)
@@ -97,13 +97,13 @@ class LSTM[K: RNNKwargs](Estimator[K]):
return z[:, -1, :], hidden
def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]:
def loss(self, **kwargs: Unpack[Kw]) -> Generator[Tensor]:
predictions = self(**kwargs)[0]
labels = kwargs["labels"]
yield F.mse_loss(predictions, labels)
def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]:
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
with torch.no_grad():
loss = next(self.loss(**kwargs)).item()
@@ -145,8 +145,8 @@ class LSTM[K: RNNKwargs](Estimator[K]):
self,
writer: SummaryWriter,
step: int | None = None,
val: bool = False,
**kwargs: Unpack[K],
group: str | None = None,
**kwargs: Unpack[Kw],
) -> None:
return None
@@ -165,7 +165,7 @@ class MultiheadLSTMKwargs(EstimatorKwargs):
auxiliary: NotRequired[Tensor]
class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]):
class MultiheadLSTM[Kw: MultiheadLSTMKwargs](Estimator[Kw]):
def __init__(
self,
input_dim: int,
@@ -223,7 +223,7 @@ class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]):
max=1.0,
)
def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]:
def forward(self, **kwargs: Unpack[Kw]) -> tuple[Tensor, ...]:
inputs = kwargs["inputs"]
# data shaped (B, C, T); map to (B, T, C)
@@ -237,7 +237,7 @@ class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]):
return z[:, -1, :], zs[:, -1, :]
def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]:
def loss(self, **kwargs: Unpack[Kw]) -> Generator[Tensor]:
pred, pred_aux = self(**kwargs)
labels = kwargs["labels"]
aux_labels = kwargs.get("auxiliary")
@@ -247,7 +247,7 @@ class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]):
else:
yield F.mse_loss(pred, labels) + F.mse_loss(pred_aux, aux_labels)
def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]:
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
return {
"grad_norm": get_grad_norm(self)
}
@@ -279,8 +279,8 @@ class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]):
self,
writer: SummaryWriter,
step: int | None = None,
val: bool = False,
**kwargs: Unpack[K],
group: str | None = None,
**kwargs: Unpack[Kw],
) -> None:
return None
@@ -293,7 +293,7 @@ class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]):
logger.info(f"| > {self.output_dim=}")
class ConvGRU[K: RNNKwargs](Estimator[K]):
class ConvGRU[Kw: RNNKwargs](Estimator[Kw]):
"""
Base recurrent convolutional architecture.
@@ -404,7 +404,7 @@ class ConvGRU[K: RNNKwargs](Estimator[K]):
max=1.0,
)
def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]:
def forward(self, **kwargs: Unpack[Kw]) -> tuple[Tensor, ...]:
inputs = kwargs["inputs"]
# embedding shaped (B, C, T)
@@ -430,7 +430,7 @@ class ConvGRU[K: RNNKwargs](Estimator[K]):
return (z,)
def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]:
def loss(self, **kwargs: Unpack[Kw]) -> Generator[Tensor]:
predictions = self(**kwargs)[0]
labels = kwargs["labels"]
@@ -439,7 +439,7 @@ class ConvGRU[K: RNNKwargs](Estimator[K]):
yield F.mse_loss(predictions, labels, reduction="mean")
def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]:
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
with torch.no_grad():
loss = next(self.loss(**kwargs)).item()
@@ -480,8 +480,8 @@ class ConvGRU[K: RNNKwargs](Estimator[K]):
self,
writer: SummaryWriter,
step: int | None = None,
val: bool = False,
**kwargs: Unpack[K],
group: str | None = None,
**kwargs: Unpack[Kw],
) -> None:
return None

View File

@@ -105,7 +105,8 @@ class Trainer[I, K: EstimatorKwargs]:
"""
self._epoch: int = 1
self._event_log = defaultdict(lambda: defaultdict(dict))
self._summary = defaultdict(lambda: defaultdict(dict))
self._event_log = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
self._val_loss = float("inf")
self._best_val_loss = float("inf")
@@ -117,7 +118,6 @@ class Trainer[I, K: EstimatorKwargs]:
train_loader: DataLoader,
batch_estimator_map: Callable[[I, Self], K],
optimizers: tuple[Optimizer, ...],
writer: SummaryWriter,
max_grad_norm: float | None = None,
) -> list[float]:
"""
@@ -154,15 +154,8 @@ class Trainer[I, K: EstimatorKwargs]:
with tqdm(train_loader, unit="batch") as batches:
for i, batch_data in enumerate(batches):
est_kwargs = batch_estimator_map(batch_data, self)
# one-time logging
if self._step == 0:
writer.add_graph(
ModelWrapper(self.estimator),
est_kwargs
)
losses = self.estimator.loss(**est_kwargs)
for o_idx, (loss, optimizer) in enumerate(
zip(losses, optimizers, strict=True)
):
@@ -186,9 +179,6 @@ class Trainer[I, K: EstimatorKwargs]:
loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
batches.set_postfix(loss=f"{loss_avg:8.2f}")
# grab the assured `inputs` key to measure num samples
self._step += len(est_kwargs["inputs"])
# step estimator hyperparam schedules
self.estimator.epoch_step()
@@ -199,7 +189,6 @@ class Trainer[I, K: EstimatorKwargs]:
loader: DataLoader,
batch_estimator_map: Callable[[I, Self], K],
loader_label: str,
writer: SummaryWriter,
) -> list[float]:
"""
Perform and record validation scores for a single epoch.
@@ -230,8 +219,22 @@ class Trainer[I, K: EstimatorKwargs]:
with tqdm(loader, unit="batch") as batches:
for i, batch_data in enumerate(batches):
est_kwargs = batch_estimator_map(batch_data, self)
losses = self.estimator.loss(**est_kwargs)
# one-time logging
if self._epoch == 0:
self._writer.add_graph(
ModelWrapper(self.estimator), est_kwargs
)
# once-per-epoch logging
if i == 0:
self.estimator.epoch_write(
self._writer,
step=self._epoch,
group=loader_label,
**est_kwargs
)
loss_items = []
for o_idx, loss in enumerate(losses):
if len(loss_sums) <= o_idx:
@@ -247,17 +250,12 @@ class Trainer[I, K: EstimatorKwargs]:
# log individual loss terms after each batch
for o_idx, loss_item in enumerate(loss_items):
self._event_log[run_prefix][loader][name][step] = value
self._add_summary_item(
loader_label, f"loss_{o_idx}", loss_item
)
self._log_event(loader_label, f"loss_{o_idx}", loss_item)
# log metrics for batch
estimator_metrics = self.estimator.metrics(**est_kwargs)
for metric_name, metric_value in estimator_metrics.items():
self._add_summary_item(
f"{loader_label}_{metric_name}", metric_value
)
self._log_event(loader_label, metric_name, metric_value)
return loss_sums
@@ -265,8 +263,7 @@ class Trainer[I, K: EstimatorKwargs]:
self,
loaders: list[DataLoader],
batch_estimator_map: Callable[[I, Self], K],
labels: list[str],
writer: SummaryWriter,
loader_labels: list[str],
) -> dict[str, list[float]]:
"""
Evaluate estimator over each provided dataloader.
@@ -284,8 +281,8 @@ class Trainer[I, K: EstimatorKwargs]:
"""
return {
label: self._eval_epoch(loader, label, batch_estimator_map, writer)
for loader, label in zip(loaders, labels, strict=True)
label: self._eval_epoch(loader, batch_estimator_map, label)
for loader, label in zip(loaders, loader_labels, strict=True)
}
def train[B](
@@ -307,7 +304,7 @@ class Trainer[I, K: EstimatorKwargs]:
summarize_every: int = 1,
chkpt_every: int = 1,
resume_latest: bool = False,
run_prefix: str | None = None,
session_name: str | None = None,
summary_writer: SummaryWriter | None = None,
aux_loaders: list[DataLoader] | None = None,
aux_loader_labels: list[str] | None = None,
@@ -362,10 +359,42 @@ class Trainer[I, K: EstimatorKwargs]:
customized) doesn't consistently yield a known type shape, however,
so it's not appropriate to use ``I`` as the callable param type.
.. todo::
.. admonition:: On session management
- Align eval stage of both train and val; currently train is from
before updates, val is from after at a given epoch
This method works around an implicit notion of training sessions.
Estimators are set during instantiation and effectively coupled
with ``Trainer`` instances, but datasets can be supplied
dynamically here. One can, for instance, run under one condition
(specific dataset, number of epochs, etc), then resume later under
another. Relevant details persist across calls: the estimator is
still attached, best val scores stowed, current epoch tracked. By
default, new ``session_names`` are always generated, but you can
write to the same TB location if you use the same
``session_name`` across calls; that's about as close to a direct
training resume as you could want.
If restarting training on new datasets, including short
fine-tuning on training-plus-validation data, it's often sensible
to call ``.reset()`` between ``.train()`` calls. While the same
estimator will be used, tracked variables will be wiped; subsequent
model updates take place under a fresh epoch, with no val losses, and
are logged under a separate TB session. This is the general approach to
"piecemeal" training, i.e., incremental model updates under varying
conditions (often just data changes).
.. warning::
Validation convergence when there are multiple losses may be
ambiguous. These are cases where certain parameter sets are
optimized independently; the sum over these losses may not reflect
expected or consistent behavior. For instance, we may record a low
cumulative loss early with a small 1st loss and moderate 2nd loss,
while later encountering a moderate 1st loss and small 2nd loss. We
might prefer the latter case, while ``_converged()`` will stick to
the former -- we need to consider possible weighting across losses,
or storing possibly several best models (e.g., for each loss, the
model that scores best, plus the one scoring best cumulatively,
etc).
Parameters:
dataset: dataset to train the estimator
@@ -402,16 +431,14 @@ class Trainer[I, K: EstimatorKwargs]:
logger.info(f"| > with device: {self.device}")
logger.info(f"| > core count: {os.cpu_count()}")
writer: SummaryWriter
run_prefix = run_prefix or str(int(time.time()))
if summary_writer is None:
writer = SummaryWriter(f"{Path(self.tblog_dir, run_prefix)}")
else:
writer = summary_writer
self._session_name = session_name or str(int(time.time()))
tblog_path = Path(self.tblog_dir, self._session_name)
self._writer = summary_writer or SummaryWriter(f"{tblog_path}")
aux_loaders = aux_loaders or []
aux_loader_labels = aux_loader_labels or []
optimizers = self.estimator.optimizers(lr=lr, eps=eps)
train_loader, val_loader = self.get_dataloaders(
dataset,
batch_size,
@@ -426,12 +453,8 @@ class Trainer[I, K: EstimatorKwargs]:
loader_labels = ["train", "val", *aux_loader_labels]
# evaluate model on dataloaders once before training starts
self._eval_loaders(loaders, batch_estimator_map, # loader_labels, writer)
self._eval_loaders(loaders, batch_estimator_map, loader_labels)
optimizers = self.estimator.optimizers(lr=lr, eps=eps)
self._step = 0
self._epoch = 1 # start from 1 for logging convenience
while self._epoch <= max_epochs and not self._converged(
self._epoch, stop_after_epochs
):
@@ -446,41 +469,22 @@ class Trainer[I, K: EstimatorKwargs]:
batch_estimator_map,
optimizers,
max_grad_norm,
writer,
)
self._add_summary_item(
"epoch_time_sec", time.time() - epoch_start_time
)
# writer steps are measured in number of samples; log epoch as an
# alternative step resolution for interpreting event timing
self._add_summary_item("step", float(self._step))
# once-per-epoch logging
self.estimator.epoch_write(
writer,
step=self._step,
val=True,
**est_kwargs
)
epoch_end_time = time.time() - epoch_start_time
self._log_event("train", "epoch_duration", epoch_end_time)
loss_sum_map = self._eval_loaders(
loaders, loader_labels, batch_estimator_map, writer
loaders,
batch_estimator_map,
loader_labels,
)
# convergence of multiple losses may be ambiguous
val_loss_sums = loss_sum_map["val"]
self._val_loss = sum(val_loss_sums) / len(val_loader)
if self._epoch % summarize_every == 0:
self._summarize(writer, self._epoch, run_prefix)
self._summarize()
if self._epoch % chkpt_every == 0:
self.save_model(self.chkpt_dir, self._epoch, run_prefix)
self.save_model()
self._epoch += 1
return self.estimator
@@ -514,6 +518,10 @@ class Trainer[I, K: EstimatorKwargs]:
) -> tuple[DataLoader, DataLoader]:
"""
Create training and validation dataloaders for the provided dataset.
.. todo::
Decide on policy for empty val dataloaders
"""
if dataset_split_kwargs is None:
@@ -571,12 +579,7 @@ class Trainer[I, K: EstimatorKwargs]:
return train_loader, val_loader
def _summarize(
self,
writer: SummaryWriter,
epoch: int,
run_prefix: str
) -> None:
def _summarize(self) -> None:
"""
Flush the training summary to the TensorBoard summary writer.
@@ -588,21 +591,24 @@ class Trainer[I, K: EstimatorKwargs]:
averages from epochs 1-10.
"""
summary_values = defaultdict(list)
for name, records in self._summary.items():
for value, step in records:
writer.add_scalar(name, value, step)
# writer.add_scalar(name, value, self._epoch)
summary_values[name].append(value)
self._event_log[run_prefix][name][step] = value
epoch_values = defaultdict(lambda: defaultdict(list))
for group, records in self._summary.items():
for name, steps in records.items():
for step, value in steps.items():
self._writer.add_scalar(f"{group}-{name}", value, step)
if step == self._epoch:
epoch_values[group][name].append(value)
print(f"==== Epoch [{epoch}] summary ====")
for name, values in summary_values.items():
mean_value = torch.tensor(values).mean().item()
print(f"> ({len(values)}) {name} :: {mean_value:.2f}")
print(f"==== Epoch [{self._epoch}] summary ====")
for group, records in epoch_values.items():
for name, values in records.items():
mean_value = torch.tensor(values).mean().item()
print(
f"> ({len(values)}) [{group}] {name} :: {mean_value:.2f}"
)
writer.flush()
self._summary = defaultdict(list)
self._writer.flush()
self._summary = defaultdict(lambda: defaultdict(dict))
def _get_optimizer_parameters(
self,
@@ -615,11 +621,9 @@ class Trainer[I, K: EstimatorKwargs]:
if param.grad is not None
]
def _add_summary_item(self, group: str, name: str, value: float) -> None:
# self._summary[name].append((value, self._step))
self._summary[name].append((value, self._epoch))
self._event_log[run_prefix][name][step] = value
def _log_event(self, group: str, name: str, value: float) -> None:
self._summary[group][name][self._epoch] = value
self._event_log[self._session_name][group][name][self._epoch] = value
def get_batch_outputs[B](
self,
@@ -646,12 +650,7 @@ class Trainer[I, K: EstimatorKwargs]:
return metrics
def save_model(
self,
chkpt_dir: str | Path,
epoch: int,
run_prefix: str,
) -> None:
def save_model(self) -> None:
"""
Save a model checkpoint.
"""
@@ -661,9 +660,9 @@ class Trainer[I, K: EstimatorKwargs]:
model_buff.seek(0)
model_class = self.estimator.__class__.__name__
chkpt_name = f"m_{model_class}-e_{epoch}.pth"
chkpt_name = f"m_{model_class}-e_{self._epoch}.pth"
chkpt_dir = Path(chkpt_dir, run_prefix)
chkpt_dir = Path(self.chkpt_dir, self._session_name)
chkpt_path = Path(chkpt_dir, chkpt_name)
chkpt_dir.mkdir(parents=True, exist_ok=True)