2 Commits

Author SHA1 Message Date
fdccb4c5eb update train loop eval logic 2026-03-31 22:52:27 -07:00
ba0c804d5e implement CLI trainer utility, update progress bar logging 2026-03-25 02:25:28 -07:00
19 changed files with 543 additions and 96 deletions

29
example/example.json Normal file
View File

@@ -0,0 +1,29 @@
{
    "estimator_name": "mlp",
    "dataset_name": "random_xy_dataset",
    "dataloader_name": "supervised_data_loader",
    "estimator_kwargs": {
        "input_dim": 4,
        "output_dim": 2
    },
    "dataset_kwargs": {
        "num_samples": 100000,
        "preload": true,
        "input_dim": 4,
        "output_dim": 2
    },
    "dataset_split_fracs": {
        "train": 0.4,
        "val": 0.3,
        "aux": [0.3]
    },
    "dataloader_kwargs": {
        "batch_size": 16
    },
    "train_kwargs": {
        "summarize_every": 20,
        "max_epochs": 100,
        "stop_after_epochs": 100
    },
    "load_only": false
}

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "trainlib"
version = "0.2.0"
version = "0.3.1"
description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training."
requires-python = ">=3.13"
authors = [

18
trainlib/__main__.py Normal file
View File

@@ -0,0 +1,18 @@
import logging

from trainlib.cli import create_parser


def main() -> None:
    parser = create_parser()
    args = parser.parse_args()

    # handle the top-level --log-level option before dispatching to a subcommand
    if hasattr(args, "log_level") and args.log_level is not None:
        logging.basicConfig(level=args.log_level)

    if "func" in args:
        args.func(args)
    else:
        parser.print_help()


if __name__ == "__main__":
    main()

26
trainlib/cli/__init__.py Normal file
View File

@@ -0,0 +1,26 @@
import logging
import argparse

from trainlib.cli import train

logger: logging.Logger = logging.getLogger(__name__)


def create_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="trainlib cli",
        # formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--log-level",
        type=int,
        metavar="int",
        choices=[10, 20, 30, 40, 50],
        help="Log level: 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL",
    )
    subparsers = parser.add_subparsers(help="subcommand help")
    train.register_parser(subparsers)
    return parser

164
trainlib/cli/train.py Normal file
View File

@@ -0,0 +1,164 @@
import gc
import json
import argparse
from typing import Any
from argparse import _SubParsersAction

import torch

from trainlib.trainer import Trainer
from trainlib.datasets import dataset_map
from trainlib.estimator import Estimator
from trainlib.estimators import estimator_map
from trainlib.dataloaders import dataloader_map


def prepare_run() -> None:
    # report and release cached CUDA memory before a run
    memory_allocated = torch.cuda.memory_allocated() / 1024**3  # GB
    print(f"CUDA allocated: {memory_allocated}GB")

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()


def run(
    estimator_name: str,
    dataset_name: str,
    dataloader_name: str,
    estimator_kwargs: dict[str, Any] | None = None,
    dataset_kwargs: dict[str, Any] | None = None,
    dataset_split_fracs: dict[str, Any] | None = None,
    dataset_split_kwargs: dict[str, Any] | None = None,
    dataloader_kwargs: dict[str, Any] | None = None,
    trainer_kwargs: dict[str, Any] | None = None,
    train_kwargs: dict[str, Any] | None = None,
    load_only: bool = False,
) -> Trainer | Estimator:
    try:
        estimator_cls = estimator_map[estimator_name]
    except KeyError as err:
        raise ValueError(
            f"Invalid estimator name '{estimator_name}', "
            f"must be one of {estimator_map.keys()}"
        ) from err

    try:
        dataset_cls = dataset_map[dataset_name]
    except KeyError as err:
        raise ValueError(
            f"Invalid dataset name '{dataset_name}', "
            f"must be one of {dataset_map.keys()}"
        ) from err

    try:
        dataloader_cls = dataloader_map[dataloader_name]
    except KeyError as err:
        raise ValueError(
            f"Invalid dataloader name '{dataloader_name}', "
            f"must be one of {dataloader_map.keys()}"
        ) from err

    estimator_kwargs = estimator_kwargs or {}
    dataset_kwargs = dataset_kwargs or {}
    dataset_split_fracs = dataset_split_fracs or {}
    dataset_split_kwargs = dataset_split_kwargs or {}
    dataloader_kwargs = dataloader_kwargs or {}
    trainer_kwargs = trainer_kwargs or {}
    train_kwargs = train_kwargs or {}

    default_estimator_kwargs = {}
    default_dataset_kwargs = {}
    default_dataset_split_kwargs = {}
    default_dataset_split_fracs = {"train": 1.0, "val": 0.0, "aux": []}
    default_dataloader_kwargs = {}
    default_trainer_kwargs = {}
    default_train_kwargs = {}

    estimator_kwargs = {**default_estimator_kwargs, **estimator_kwargs}
    dataset_kwargs = {**default_dataset_kwargs, **dataset_kwargs}
    dataset_split_kwargs = {**default_dataset_split_kwargs, **dataset_split_kwargs}
    dataset_split_fracs = {**default_dataset_split_fracs, **dataset_split_fracs}
    dataloader_kwargs = {**default_dataloader_kwargs, **dataloader_kwargs}
    trainer_kwargs = {**default_trainer_kwargs, **trainer_kwargs}
    train_kwargs = {**default_train_kwargs, **train_kwargs}

    estimator = estimator_cls(**estimator_kwargs)
    dataset = dataset_cls(**dataset_kwargs)

    train_dataset, val_dataset, *aux_datasets = dataset.split(
        fracs=[
            dataset_split_fracs["train"],
            dataset_split_fracs["val"],
            *dataset_split_fracs["aux"]
        ],
        **dataset_split_kwargs
    )
    train_loader = dataloader_cls(train_dataset, **dataloader_kwargs)
    val_loader = dataloader_cls(val_dataset, **dataloader_kwargs)
    aux_loaders = [
        dataloader_cls(aux_dataset, **dataloader_kwargs)
        for aux_dataset in aux_datasets
    ]

    trainer = Trainer(
        estimator,
        **trainer_kwargs,
    )
    if load_only:
        return trainer

    return trainer.train(
        train_loader=train_loader,
        val_loader=val_loader,
        aux_loaders=aux_loaders,
        **train_kwargs,
    )


def run_from_json(
    parameters_json: str | None = None,
    parameters_file: str | None = None,
) -> Trainer | Estimator:
    if not (parameters_json or parameters_file):
        raise ValueError("parameter json or file required")

    parameters: dict[str, Any]
    if parameters_json:
        try:
            parameters = json.loads(parameters_json)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON format: {e}") from e
        except Exception as e:
            raise ValueError(f"Error loading JSON parameters: {e}") from e
    elif parameters_file:
        try:
            with open(parameters_file, encoding="utf-8") as f:
                parameters = json.load(f)
        except FileNotFoundError as e:
            raise ValueError(f"JSON file not found: {parameters_file}") from e
        except Exception as e:
            raise ValueError(f"Error loading JSON parameters: {e}") from e

    return run(**parameters)


def handle_train(args: argparse.Namespace) -> None:
    run_from_json(args.parameters_json, args.parameters_file)


def register_parser(subparsers: _SubParsersAction) -> None:
    parser = subparsers.add_parser("train", help="run training loop")
    parser.add_argument(
        "--parameters-json",
        type=str,
        help="Raw JSON string with train parameters",
    )
    parser.add_argument(
        "--parameters-file",
        type=str,
        help="Path to JSON file with train parameters",
    )
    parser.set_defaults(func=handle_train)
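
A rough usage sketch of these new entry points, assuming the package is installed and the example config from this commit is on disk (the CLI equivalent would be roughly `python -m trainlib --log-level 20 train --parameters-file example/example.json`):

from trainlib.cli.train import run_from_json

# parses the JSON config and dispatches to run(); returns the Trainer when
# "load_only" is true, otherwise whatever trainer.train() returns
result = run_from_json(parameters_file="example/example.json")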

View File

@@ -0,0 +1,12 @@
from trainlib.dataloader import EstimatorDataLoader
from trainlib.utils.text import camel_to_snake
from trainlib.dataloaders.memory import SupervisedDataLoader

_dataloaders = [
    SupervisedDataLoader,
]

dataloader_map: dict[str, type[EstimatorDataLoader]] = {
    camel_to_snake(_dataloader.__name__): _dataloader
    for _dataloader in _dataloaders
}

View File

@@ -0,0 +1,17 @@
from torch import Tensor

from trainlib.estimator import SupervisedKwargs
from trainlib.dataloader import EstimatorDataLoader


class SupervisedDataLoader(
    EstimatorDataLoader[tuple[Tensor, Tensor], SupervisedKwargs]
):
    def batch_to_est_kwargs(
        self,
        batch_data: tuple[Tensor, Tensor]
    ) -> SupervisedKwargs:
        return SupervisedKwargs(
            inputs=batch_data[0],
            labels=batch_data[1],
        )
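
To illustrate the mapping (a sketch with made-up shapes; `loader` here is a hypothetical `SupervisedDataLoader` instance): each `(inputs, labels)` tuple becomes the keyword dict that the trainer splats into `estimator.loss(**batch_kwargs)`.

import torch

# hypothetical batch of 16 samples matching example.json's input_dim=4, output_dim=2
batch = (torch.randn(16, 4), torch.randn(16, 2))
kwargs = loader.batch_to_est_kwargs(batch)
# kwargs == {"inputs": <16x4 tensor>, "labels": <16x2 tensor>}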

View File

@@ -0,0 +1,12 @@
from trainlib.dataset import BatchedDataset
from trainlib.utils.text import camel_to_snake
from trainlib.datasets.memory import RandomXYDataset

_datasets = [
    RandomXYDataset,
]

dataset_map: dict[str, type[BatchedDataset]] = {
    camel_to_snake(_dataset.__name__): _dataset
    for _dataset in _datasets
}

View File

@@ -73,13 +73,33 @@ class RecordDataset[T: NamedTuple](HomogenousDataset[int, T, T]):
from typing import Unpack

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import TensorDataset

from trainlib.domain import SequenceDomain
from trainlib.dataset import TupleDataset, DatasetKwargs


class RandomXYDataset(TupleDataset[Tensor]):
    def __init__(
        self,
        num_samples: int,
        input_dim: int,
        output_dim: int,
        **kwargs: Unpack[DatasetKwargs],
    ) -> None:
        domain = SequenceDomain[tuple[Tensor, Tensor]](
            TensorDataset(
                torch.randn((num_samples, input_dim)),
                torch.randn((num_samples, output_dim))
            ),
        )
        super().__init__(domain, **kwargs)
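
A rough sketch of how this dataset is consumed by `run()` in the new CLI module (the `split()` signature is inferred from that call; dimensions follow example.json):

ds = RandomXYDataset(num_samples=1000, input_dim=4, output_dim=2)
# fracs mirror example.json's dataset_split_fracs: train/val plus one aux split
train_ds, val_ds, aux_ds = ds.split(fracs=[0.4, 0.3, 0.3])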
class SlidingWindowDataset(TupleDataset[Tensor]):
    def __init__(
        self,
@@ -88,12 +108,26 @@ class SlidingWindowDataset(TupleDataset[Tensor]):
        offset: int = 0,
        lookahead: int = 1,
        num_windows: int = 1,
        pad_mode: str = "constant",
        #fill_with: str = "zero",
        **kwargs: Unpack[DatasetKwargs],
    ) -> None:
        """
        Parameters:
            fill_with: strategy to use for padding values in windows
                - `zero`: fill with zeros
                - `left`: use nearest window column (repeat leftmost)
                - `mean`: fill with the window mean

        TODO: implement the `fill_with` options; currently a `pad_mode` is
        just passed through to the functional pad call, which does the job.
        """
        self.lookback = lookback
        self.offset = offset
        self.lookahead = lookahead
        self.num_windows = num_windows
        self.pad_mode = pad_mode

        super().__init__(domain, **kwargs)
@@ -103,7 +137,7 @@ class SlidingWindowDataset(TupleDataset[Tensor]):
batch_index: int,
) -> list[tuple[Tensor, ...]]:
"""
Backward pads first sequence over (lookback-1) length, and steps the
Backward pads window sequences over (lookback-1) length, and steps the
remaining items forward by the lookahead.
Batch data:
@@ -146,7 +180,7 @@ class SlidingWindowDataset(TupleDataset[Tensor]):
exceeds the offset.
To get windows starting with the first index at the left: we first set
out window size (call it L), determined by `lookback`. Then the
our window size (call it L), determined by `lookback`. Then the
rightmost index we want will be `L-1`, which determines our `offset`
setting.
@@ -173,7 +207,7 @@ class SlidingWindowDataset(TupleDataset[Tensor]):
# for window sized `lb`, we pad with `lb-1` zeros. We then take off
# the amount of our offset, which in the extreme cases does no
# padding.
xip = F.pad(t, ((lb-1) - off, 0))
xip = F.pad(t, ((lb-1) - off, 0), mode=self.pad_mode)
# extract sliding windows over the padded tensor
# unfold(-1, lb, 1) slides over the last dim, 1 step at a time, for
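
A standalone sketch of the padding-plus-unfold trick described in the comments above (plain torch, not tied to this class): for window size `lb = 3` and zero offset, left-padding with `lb-1` zeros and then unfolding yields one window ending at each original index.

import torch
import torch.nn.functional as F

t = torch.arange(5.)                    # tensor([0., 1., 2., 3., 4.])
lb, off = 3, 0
padded = F.pad(t, ((lb - 1) - off, 0))  # tensor([0., 0., 0., 1., 2., 3., 4.])
windows = padded.unfold(-1, lb, 1)      # shape (5, 3); one row per original index
# windows[0] == tensor([0., 0., 0.]); windows[4] == tensor([2., 3., 4.])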

View File

@@ -35,6 +35,10 @@ class EstimatorKwargs(TypedDict):
    inputs: Tensor


class SupervisedKwargs(EstimatorKwargs):
    labels: Tensor


class Estimator[Kw: EstimatorKwargs](nn.Module):
    """
    Estimator base class.

View File

@@ -0,0 +1,15 @@
from trainlib.estimator import Estimator
from trainlib.utils.text import camel_to_snake
from trainlib.estimators.mlp import MLP
from trainlib.estimators.rnn import LSTM, ConvGRU, MultiheadLSTM

_estimators: list[type[Estimator]] = [
    MLP,
    LSTM,
    MultiheadLSTM,
    ConvGRU,
]

estimator_map: dict[str, type[Estimator]] = {
    camel_to_snake(_estimator.__name__): _estimator
    for _estimator in _estimators
}
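
These registries (together with `dataset_map` and `dataloader_map` above) are what `run()` indexes by the snake_case names used in example.json. A quick lookup sketch; the constructor kwargs are illustrative, taken from example.json's estimator_kwargs:

from trainlib.estimators import estimator_map
from trainlib.datasets import dataset_map

estimator_cls = estimator_map["mlp"]                  # camel_to_snake("MLP")
dataset_cls = dataset_map["random_xy_dataset"]        # camel_to_snake("RandomXYDataset")
estimator = estimator_cls(input_dim=4, output_dim=2)  # kwargs from the JSON config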

View File

@@ -8,6 +8,7 @@ from torch import nn, Tensor
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from trainlib.utils import op
from trainlib.estimator import Estimator, EstimatorKwargs
from trainlib.utils.type import OptimizerKwargs
from trainlib.utils.module import get_grad_norm
@@ -102,6 +103,7 @@ class LSTM[Kw: RNNKwargs](Estimator[Kw]):
labels = kwargs["labels"]
yield F.mse_loss(predictions, labels)
#yield F.l1_loss(predictions, labels)
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
with torch.no_grad():
@@ -109,12 +111,16 @@ class LSTM[Kw: RNNKwargs](Estimator[Kw]):
predictions = self(**kwargs)[0]
labels = kwargs["labels"]
mse = F.mse_loss(predictions, labels).item()
mae = F.l1_loss(predictions, labels).item()
r2 = op.r2_score(predictions, labels).item()
return {
# "loss": loss,
"mse": loss,
"loss": loss,
"mse": mse,
"mae": mae,
"r2": r2,
"grad_norm": get_grad_norm(self)
}
@@ -377,7 +383,8 @@ class ConvGRU[Kw: RNNKwargs](Estimator[Kw]):
# will be (B, T, C), applies indep at each time step across channels
# self.dense_z = nn.Linear(layer_in_dim, self.output_dim)
# will be (B, C, T), applies indep at each time step across channels
# will be (B, Co, 1), applies indep at each channel across temporal dim
# size time steps
self.dense_z = TDNNLayer(
layer_in_dim,
self.output_dim,
@@ -438,6 +445,7 @@ class ConvGRU[Kw: RNNKwargs](Estimator[Kw]):
predictions = predictions.squeeze(-1)
yield F.mse_loss(predictions, labels, reduction="mean")
#yield F.l1_loss(predictions, labels)
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
with torch.no_grad():
@@ -445,11 +453,16 @@ class ConvGRU[Kw: RNNKwargs](Estimator[Kw]):
predictions = self(**kwargs)[0].squeeze(-1)
labels = kwargs["labels"]
mse = F.mse_loss(predictions, labels).item()
mae = F.l1_loss(predictions, labels).item()
r2 = op.r2_score(predictions, labels).item()
return {
"mse": loss,
"loss": loss,
"mse": mse,
"mae": mae,
"r2": r2,
"grad_norm": get_grad_norm(self)
}

View File

@@ -64,14 +64,17 @@ class Plotter[Kw: EstimatorKwargs]:
for i, loader in enumerate(self.dataloaders):
label = self.dataloader_labels[i]
actual = torch.cat([
actual = [
self.kw_to_actual(batch_kwargs).detach().cpu()
for batch_kwargs in loader
])
output = torch.cat([
]
actual = torch.cat([ai.reshape(*([*ai.shape]+[1])[:2]) for ai in actual])
output = [
self.trainer.estimator(**batch_kwargs)[0].detach().cpu()
for batch_kwargs in loader
])
]
output = torch.cat([oi.reshape(*([*oi.shape]+[1])[:2]) for oi in output])
data_tuples.append((actual, output, label))

View File

@@ -32,6 +32,7 @@ from trainlib.estimator import Estimator, EstimatorKwargs
from trainlib.utils.map import nested_defaultdict
from trainlib.dataloader import EstimatorDataLoader
from trainlib.utils.module import ModelWrapper
from trainlib.utils.session import ensure_same_device
logger: logging.Logger = logging.getLogger(__name__)
@@ -103,24 +104,39 @@ class Trainer[I, Kw: EstimatorKwargs]:
self.reset()
def reset(self) -> None:
def reset(self, resume: bool = False) -> None:
"""
Set initial tracking parameters for the primary training loop.
Parameters:
resume: if ``True``, only the stagnant epoch counter is reset, with
the aim of carrying the existing training state through a
resumed ``train()`` call. This should likely only be set when
training continues on the same dataset and the goal is to
restart convergence-loss scoring for a fresh set of
epochs. If even that element of the training loop should resume
(which should only happen if a training loop was interrupted or
a max epoch limit was reached), then this method shouldn't be
called at all between ``train()`` invocations.
"""
self._epoch: int = 0
self._summary = defaultdict(lambda: defaultdict(list))
self._conv_loss = float("inf")
self._best_conv_loss = float("inf")
self._stagnant_epochs = 0
self._best_model_state_dict: dict[str, Any] = {}
if not resume:
self._epoch: int = 0
self._summary = defaultdict(lambda: defaultdict(list))
self._conv_loss = float("inf")
self._best_conv_loss = float("inf")
self._best_conv_epoch = 0
self._best_model_state_dict: dict[str, Any] = {}
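
A sketch of the resume workflow the docstring describes (loader names are placeholders): after a ``train()`` call stops on the stagnancy criterion, resetting only the counter grants another ``stop_after_epochs`` window on the same data.

trainer.train(train_loader=train_loader, val_loader=val_loader)
# stopped because stagnant epochs exceeded stop_after_epochs; same data, so keep
# the best loss / best model state and only clear the stagnancy counter
trainer.reset(resume=True)
trainer.train(train_loader=train_loader, val_loader=val_loader)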
def _train_epoch(
self,
loader: EstimatorDataLoader[Any, Kw],
optimizers: tuple[Optimizer, ...],
max_grad_norm: float | None = None,
progress_bar: tqdm | None = None,
) -> list[float]:
"""
Train the estimator for a single epoch.
@@ -128,32 +144,40 @@ class Trainer[I, Kw: EstimatorKwargs]:
loss_sums = []
self.estimator.train()
with tqdm(loader, unit="batch") as batches:
for i, batch_kwargs in enumerate(batches):
losses = self.estimator.loss(**batch_kwargs)
for i, batch_kwargs in enumerate(loader):
batch_kwargs = ensure_same_device(batch_kwargs, self.device)
losses = self.estimator.loss(**batch_kwargs)
for o_idx, (loss, optimizer) in enumerate(
zip(losses, optimizers, strict=True)
):
if len(loss_sums) <= o_idx:
loss_sums.append(0.0)
loss_sums[o_idx] += loss.item()
for o_idx, (loss, optimizer) in enumerate(
zip(losses, optimizers, strict=True)
):
if len(loss_sums) <= o_idx:
loss_sums.append(0.0)
loss_sums[o_idx] += loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.zero_grad()
loss.backward()
# clip gradients for optimizer's parameters
if max_grad_norm is not None:
clip_grad_norm_(
self._get_optimizer_parameters(optimizer),
max_norm=max_grad_norm
)
# clip gradients for optimizer's parameters
if max_grad_norm is not None:
clip_grad_norm_(
self._get_optimizer_parameters(optimizer),
max_norm=max_grad_norm
)
optimizer.step()
optimizer.step()
# set loop loss to running average (reducing if multi-loss)
loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
batches.set_postfix(loss=f"{loss_avg:8.2f}")
# set loop loss to running average (reducing if multi-loss)
loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
if progress_bar:
progress_bar.update(1)
progress_bar.set_postfix(
epoch=self._epoch,
mode="opt",
data="train",
loss=f"{loss_avg:8.2f}",
)
# step estimator hyperparam schedules
self.estimator.epoch_step()
@@ -163,7 +187,8 @@ class Trainer[I, Kw: EstimatorKwargs]:
def _eval_epoch(
self,
loader: EstimatorDataLoader[Any, Kw],
label: str
label: str,
progress_bar: tqdm | None = None,
) -> list[float]:
"""
Perform and record validation scores for a single epoch.
@@ -191,45 +216,53 @@ class Trainer[I, Kw: EstimatorKwargs]:
loss_sums = []
self.estimator.eval()
with tqdm(loader, unit="batch") as batches:
for i, batch_kwargs in enumerate(batches):
losses = self.estimator.loss(**batch_kwargs)
for i, batch_kwargs in enumerate(loader):
batch_kwargs = ensure_same_device(batch_kwargs, self.device)
losses = self.estimator.loss(**batch_kwargs)
# one-time logging
if self._epoch == 0:
self._writer.add_graph(
ModelWrapper(self.estimator), batch_kwargs
)
# once-per-epoch logging
if i == 0:
self.estimator.epoch_write(
self._writer,
step=self._epoch,
group=label,
**batch_kwargs
)
# once-per-session logging
if self._epoch == 0 and i == 0:
self._writer.add_graph(
ModelWrapper(self.estimator), batch_kwargs
)
# once-per-epoch logging
if i == 0:
self.estimator.epoch_write(
self._writer,
step=self._epoch,
group=label,
**batch_kwargs
)
loss_items = []
for o_idx, loss in enumerate(losses):
if len(loss_sums) <= o_idx:
loss_sums.append(0.0)
loss_items = []
for o_idx, loss in enumerate(losses):
if len(loss_sums) <= o_idx:
loss_sums.append(0.0)
loss_item = loss.item()
loss_sums[o_idx] += loss_item
loss_items.append(loss_item)
loss_item = loss.item()
loss_sums[o_idx] += loss_item
loss_items.append(loss_item)
# set loop loss to running average (reducing if multi-loss)
loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
batches.set_postfix(loss=f"{loss_avg:8.2f}")
# set loop loss to running average (reducing if multi-loss)
loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
# log individual loss terms after each batch
for o_idx, loss_item in enumerate(loss_items):
self._log_event(label, f"loss_{o_idx}", loss_item)
if progress_bar:
progress_bar.update(1)
progress_bar.set_postfix(
epoch=self._epoch,
mode="eval",
data=label,
loss=f"{loss_avg:8.2f}",
)
# log metrics for batch
estimator_metrics = self.estimator.metrics(**batch_kwargs)
for metric_name, metric_value in estimator_metrics.items():
self._log_event(label, metric_name, metric_value)
# log individual loss terms after each batch
for o_idx, loss_item in enumerate(loss_items):
self._log_event(label, f"loss_{o_idx}", loss_item)
# log metrics for batch
estimator_metrics = self.estimator.metrics(**batch_kwargs)
for metric_name, metric_value in estimator_metrics.items():
self._log_event(label, metric_name, metric_value)
avg_losses = [loss_sum / (i+1) for loss_sum in loss_sums]
@@ -240,6 +273,7 @@ class Trainer[I, Kw: EstimatorKwargs]:
train_loader: EstimatorDataLoader[Any, Kw],
val_loader: EstimatorDataLoader[Any, Kw] | None = None,
aux_loaders: list[EstimatorDataLoader[Any, Kw]] | None = None,
progress_bar: tqdm | None = None,
) -> tuple[list[float], list[float] | None, *list[float]]:
"""
Evaluate estimator over each provided dataloader.
@@ -274,12 +308,15 @@ class Trainer[I, Kw: EstimatorKwargs]:
somewhere given the many possible design choices here.)
"""
train_loss = self._eval_epoch(train_loader, "train")
val_loss = self._eval_epoch(val_loader, "val") if val_loader else None
train_loss = self._eval_epoch(train_loader, "train", progress_bar)
val_loss = None
if val_loader is not None:
val_loss = self._eval_epoch(val_loader, "val", progress_bar)
aux_loaders = aux_loaders or []
aux_losses = [
self._eval_epoch(aux_loader, f"aux{i}")
self._eval_epoch(aux_loader, f"aux{i}", progress_bar)
for i, aux_loader in enumerate(aux_loaders)
]
@@ -433,27 +470,36 @@ class Trainer[I, Kw: EstimatorKwargs]:
self._session_name = session_name or str(int(time.time()))
tblog_path = Path(self.tblog_dir, self._session_name)
self._writer = summary_writer or SummaryWriter(f"{tblog_path}")
progress_bar = tqdm(train_loader, unit="batch")
# evaluate model on dataloaders once before training starts
self._eval_loaders(train_loader, val_loader, aux_loaders)
train_loss, val_loss, *_ = self._eval_loaders(
train_loader, val_loader, aux_loaders, progress_bar
)
conv_loss = val_loss if val_loss else train_loss
self._conv_loss = sum(conv_loss) / len(conv_loss)
optimizers = self.estimator.optimizers(lr=lr, eps=eps)
while self._epoch < max_epochs and not self._converged(
self._epoch, stop_after_epochs
):
while self._epoch < max_epochs and not self._converged(stop_after_epochs):
self._epoch += 1
train_frac = f"{self._epoch}/{max_epochs}"
stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}"
print(f"Training epoch {train_frac}...")
print(f"Stagnant epochs {stag_frac}...")
#train_frac = f"{self._epoch}/{max_epochs}"
#stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}"
#print(f"Training epoch {train_frac}...")
#print(f"Stagnant epochs {stag_frac}...")
epoch_start_time = time.time()
self._train_epoch(train_loader, optimizers, max_grad_norm)
self._train_epoch(
train_loader,
optimizers,
max_grad_norm,
progress_bar=progress_bar
)
epoch_end_time = time.time() - epoch_start_time
self._log_event("train", "epoch_duration", epoch_end_time)
train_loss, val_loss, _ = self._eval_loaders(
train_loader, val_loader, aux_loaders
train_loss, val_loss, *_ = self._eval_loaders(
train_loader, val_loader, aux_loaders, progress_bar
)
# determine loss to use for measuring convergence
conv_loss = val_loss if val_loss else train_loss
@@ -466,12 +512,43 @@ class Trainer[I, Kw: EstimatorKwargs]:
return self.estimator
def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
def _converged(self, stop_after_epochs: int) -> bool:
"""
Check if model has converged.
This method looks at the current "convergence loss" (validation-based
if a val set is provided to ``train()``, otherwise the training loss is
used), checking if it's the best yet recorded, incrementing the
stagnancy count if not. Convergence is asserted only if the number of
stagnant epochs exceeds ``stop_after_epochs``.
.. admonition:: Evaluation order
Convergence losses are recorded before the first training update,
so initial model states are appropriately benchmarked by the time
``_converged()`` is invoked.
If resuming training on the same dataset, one might expect only to
reset the stagnant epoch counter: you'll resume from the last
epoch, estimator state, and best seen loss, while being allowed
``stop_after_epochs`` more chances for better validation.
If picking up training on a new dataset, even in a training+validation
setting, resetting the best seen loss and best model state is
needed: you can't reliably compare the existing stats under new
data. It's somewhat ambiguous whether ``epoch`` absolutely must be
reset; you could continue logging metrics under the same named
session. But best practice would suggest restarting the epoch
count and having events logged under a new session heading when data
change.
"""
converged = False
if epoch == 0 or self._conv_loss < self._best_val_loss:
self._best_val_loss = self._conv_loss
if self._conv_loss < self._best_conv_loss:
self._stagnant_epochs = 0
self._best_conv_loss = self._conv_loss
self._best_conv_epoch = self._epoch
self._best_model_state_dict = deepcopy(self.estimator.state_dict())
else:
self._stagnant_epochs += 1
@@ -491,6 +568,7 @@ class Trainer[I, Kw: EstimatorKwargs]:
print(f"==== Epoch [{self._epoch}] summary ====")
for (group, name), epoch_map in self._summary.items():
for epoch, values in epoch_map.items():
# compute average over batch items recorded for the epoch
mean = torch.tensor(values).mean().item()
self._writer.add_scalar(f"{group}-{name}", mean, epoch)
if epoch == self._epoch:

9
trainlib/utils/op.py Normal file
View File

@@ -0,0 +1,9 @@
from torch import Tensor


def r2_score(y: Tensor, y_hat: Tensor) -> Tensor:
    # R^2 = 1 - SS_res / SS_tot, where SS_tot is the total variation of `y`
    ss_res = ((y - y_hat)**2).sum()
    ss_tot = ((y - y.mean())**2).sum()
    r2 = 1 - ss_res / ss_tot
    return r2
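
A quick sanity check of the formula (values chosen by hand, not from the repo):

import torch

y = torch.tensor([1.0, 2.0, 3.0])
y_hat = torch.tensor([1.0, 2.0, 2.0])
# ss_res = 1.0, ss_tot = 2.0  ->  r2 = 0.5
assert torch.isclose(r2_score(y, y_hat), torch.tensor(0.5))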

View File

@@ -3,6 +3,7 @@ import random
import numpy as np
import torch
from torch import Tensor
from torch.utils import _pytree as pytree
def seed_all_backends(seed: int | Tensor | None = None) -> None:
@@ -19,3 +20,9 @@ def seed_all_backends(seed: int | Tensor | None = None) -> None:
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def ensure_same_device[T](tree: T, device: str) -> T:
    return pytree.tree_map(
        lambda x: x.to(device) if isinstance(x, torch.Tensor) else x,
        tree,
    )
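
A sketch of what this does for the trainer above (device string and shapes are illustrative): every tensor leaf of a batch kwargs dict is moved, anything else passes through unchanged.

import torch

batch_kwargs = {"inputs": torch.randn(16, 4), "labels": torch.randn(16, 2)}
batch_kwargs = ensure_same_device(batch_kwargs, "cpu")
# both Tensor leaves are now on "cpu"; non-tensor values would be returned as-is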

View File

@@ -1,8 +1,14 @@
import re
from typing import Any

from colorama import Style

camel2snake_regex: re.Pattern[str] = re.compile(
    r"(?<!^)(?=[A-Z][a-z])|(?<=[a-z])(?=[A-Z])"
)


def camel_to_snake(text: str) -> str:
    return camel2snake_regex.sub("_", text).lower()


def color_text(text: str, *colorama_args: Any) -> str:
    return f"{''.join(colorama_args)}{text}{Style.RESET_ALL}"

View File

@@ -61,8 +61,8 @@ class SubplotsKwargs(TypedDict, total=False):
squeeze: bool
width_ratios: Sequence[float]
height_ratios: Sequence[float]
subplot_kw: dict[str, ...]
gridspec_kw: dict[str, ...]
subplot_kw: dict[str, Any]
gridspec_kw: dict[str, Any]
figsize: tuple[float, float]
dpi: float
layout: str

2
uv.lock generated
View File

@@ -1659,7 +1659,7 @@ wheels = [
[[package]]
name = "trainlib"
version = "0.1.2"
version = "0.3.1"
source = { editable = "." }
dependencies = [
{ name = "colorama" },