2 Commits
0.2.0 ... 0.3.1

Author SHA1 Message Date
fdccb4c5eb update train loop eval logic 2026-03-31 22:52:27 -07:00
ba0c804d5e implement CLI trainer utility, update progress bar logging 2026-03-25 02:25:28 -07:00
19 changed files with 543 additions and 96 deletions

29
example/example.json Normal file

@@ -0,0 +1,29 @@
{
    "estimator_name": "mlp",
    "dataset_name": "random_xy_dataset",
    "dataloader_name": "supervised_data_loader",
    "estimator_kwargs": {
        "input_dim": 4,
        "output_dim": 2
    },
    "dataset_kwargs": {
        "num_samples": 100000,
        "preload": true,
        "input_dim": 4,
        "output_dim": 2
    },
    "dataset_split_fracs": {
        "train": 0.4,
        "val": 0.3,
        "aux": [0.3]
    },
    "dataloader_kwargs": {
        "batch_size": 16
    },
    "train_kwargs": {
        "summarize_every": 20,
        "max_epochs": 100,
        "stop_after_epochs": 100
    },
    "load_only": false
}

pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "trainlib"
-version = "0.2.0"
+version = "0.3.1"
description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training."
requires-python = ">=3.13"
authors = [

18
trainlib/__main__.py Normal file

@@ -0,0 +1,18 @@
import logging

from trainlib.cli import create_parser

def main() -> None:
    parser = create_parser()
    args = parser.parse_args()
    # skim off log level to handle higher-level option
    if hasattr(args, "log_level") and args.log_level is not None:
        logging.basicConfig(level=args.log_level)
    args.func(args) if "func" in args else parser.print_help()

if __name__ == "__main__":
    main()

26
trainlib/cli/__init__.py Normal file

@@ -0,0 +1,26 @@
import logging
import argparse

from trainlib.cli import train

logger: logging.Logger = logging.getLogger(__name__)

def create_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="trainlib cli",
        # formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--log-level",
        type=int,
        metavar="int",
        choices=[10, 20, 30, 40, 50],
        help="Log level: 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL",
    )
    subparsers = parser.add_subparsers(help="subcommand help")
    train.register_parser(subparsers)
    return parser

164
trainlib/cli/train.py Normal file

@@ -0,0 +1,164 @@
import gc
import json
import argparse
from typing import Any
from argparse import _SubParsersAction

import torch

from trainlib.trainer import Trainer
from trainlib.datasets import dataset_map
from trainlib.estimator import Estimator
from trainlib.estimators import estimator_map
from trainlib.dataloaders import dataloader_map

def prepare_run() -> None:
    # prepare cuda memory
    memory_allocated = torch.cuda.memory_allocated() / 1024**3  # GB
    print(f"CUDA allocated: {memory_allocated}GB")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()

def run(
    estimator_name: str,
    dataset_name: str,
    dataloader_name: str,
    estimator_kwargs: dict[str, Any] | None = None,
    dataset_kwargs: dict[str, Any] | None = None,
    dataset_split_fracs: dict[str, Any] | None = None,
    dataset_split_kwargs: dict[str, Any] | None = None,
    dataloader_kwargs: dict[str, Any] | None = None,
    trainer_kwargs: dict[str, Any] | None = None,
    train_kwargs: dict[str, Any] | None = None,
    load_only: bool = False,
) -> Trainer | Estimator:
    try:
        estimator_cls = estimator_map[estimator_name]
    except KeyError as err:
        raise ValueError(
            f"Invalid estimator name '{estimator_name}', "
            f"must be one of {estimator_map.keys()}"
        ) from err
    try:
        dataset_cls = dataset_map[dataset_name]
    except KeyError as err:
        raise ValueError(
            f"Invalid dataset name '{dataset_name}', "
            f"must be one of {dataset_map.keys()}"
        ) from err
    try:
        dataloader_cls = dataloader_map[dataloader_name]
    except KeyError as err:
        raise ValueError(
            f"Invalid dataloader name '{dataloader_name}', "
            f"must be one of {dataloader_map.keys()}"
        ) from err

    estimator_kwargs = estimator_kwargs or {}
    dataset_kwargs = dataset_kwargs or {}
    dataset_split_fracs = dataset_split_fracs or {}
    dataset_split_kwargs = dataset_split_kwargs or {}
    dataloader_kwargs = dataloader_kwargs or {}
    trainer_kwargs = trainer_kwargs or {}
    train_kwargs = train_kwargs or {}

    default_estimator_kwargs = {}
    default_dataset_kwargs = {}
    default_dataset_split_kwargs = {}
    default_dataset_split_fracs = {"train": 1.0, "val": 0.0, "aux": []}
    default_dataloader_kwargs = {}
    default_trainer_kwargs = {}
    default_train_kwargs = {}

    estimator_kwargs = {**default_estimator_kwargs, **estimator_kwargs}
    dataset_kwargs = {**default_dataset_kwargs, **dataset_kwargs}
    dataset_split_kwargs = {**default_dataset_split_kwargs, **dataset_split_kwargs}
    dataset_split_fracs = {**default_dataset_split_fracs, **dataset_split_fracs}
    dataloader_kwargs = {**default_dataloader_kwargs, **dataloader_kwargs}
    trainer_kwargs = {**default_trainer_kwargs, **trainer_kwargs}
    train_kwargs = {**default_train_kwargs, **train_kwargs}

    estimator = estimator_cls(**estimator_kwargs)
    dataset = dataset_cls(**dataset_kwargs)
    train_dataset, val_dataset, *aux_datasets = dataset.split(
        fracs=[
            dataset_split_fracs["train"],
            dataset_split_fracs["val"],
            *dataset_split_fracs["aux"]
        ],
        **dataset_split_kwargs
    )
    train_loader = dataloader_cls(train_dataset, **dataloader_kwargs)
    val_loader = dataloader_cls(val_dataset, **dataloader_kwargs)
    aux_loaders = [
        dataloader_cls(aux_dataset, **dataloader_kwargs)
        for aux_dataset in aux_datasets
    ]
    trainer = Trainer(
        estimator,
        **trainer_kwargs,
    )
    if load_only:
        return trainer
    return trainer.train(
        train_loader=train_loader,
        val_loader=val_loader,
        aux_loaders=aux_loaders,
        **train_kwargs,
    )

def run_from_json(
    parameters_json: str | None = None,
    parameters_file: str | None = None,
) -> Trainer | Estimator:
    if not (parameters_json or parameters_file):
        raise ValueError("parameter json or file required")
    parameters: dict[str, Any]
    if parameters_json:
        try:
            parameters = json.loads(parameters_json)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON format: {e}") from e
        except Exception as e:
            raise ValueError(f"Error loading JSON parameters: {e}") from e
    elif parameters_file:
        try:
            with open(parameters_file, encoding="utf-8") as f:
                parameters = json.load(f)
        except FileNotFoundError as e:
            raise ValueError(f"JSON file not found: {parameters_file}") from e
        except Exception as e:
            raise ValueError(f"Error loading JSON parameters: {e}") from e
    return run(**parameters)

def handle_train(args: argparse.Namespace) -> None:
    run_from_json(args.parameters_json, args.parameters_file)

def register_parser(subparsers: _SubParsersAction) -> None:
    parser = subparsers.add_parser("train", help="run training loop")
    parser.add_argument(
        "--parameters-json",
        type=str,
        help="Raw JSON string with train parameters",
    )
    parser.add_argument(
        "--parameters-file",
        type=str,
        help="Path to JSON file with train parameters",
    )
    parser.set_defaults(func=handle_train)
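
As a usage sketch (outside this changeset), the new subcommand can be driven either programmatically or through the trainlib/__main__.py entry point above, here with the committed example/example.json:

from trainlib.cli.train import run_from_json

# CLI equivalent (the --log-level option belongs to the top-level parser):
#   python -m trainlib --log-level 20 train --parameters-file example/example.json
result = run_from_json(parameters_file="example/example.json")
# result is the Trainer when "load_only" is true, otherwise the trained Estimator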

trainlib/dataloaders/__init__.py Normal file

@@ -0,0 +1,12 @@
from trainlib.dataloader import EstimatorDataLoader
from trainlib.utils.text import camel_to_snake

from trainlib.dataloaders.memory import SupervisedDataLoader

_dataloaders = [
    SupervisedDataLoader,
]

dataloader_map: dict[str, type[EstimatorDataLoader]] = {
    camel_to_snake(_dataloader.__name__): _dataloader
    for _dataloader in _dataloaders
}

trainlib/dataloaders/memory.py Normal file

@@ -0,0 +1,17 @@
from torch import Tensor

from trainlib.estimator import SupervisedKwargs
from trainlib.dataloader import EstimatorDataLoader

class SupervisedDataLoader(
    EstimatorDataLoader[tuple[Tensor, Tensor], SupervisedKwargs]
):
    def batch_to_est_kwargs(
        self,
        batch_data: tuple[Tensor, Tensor]
    ) -> SupervisedKwargs:
        return SupervisedKwargs(
            inputs=batch_data[0],
            labels=batch_data[1],
        )

trainlib/datasets/__init__.py Normal file

@@ -0,0 +1,12 @@
from trainlib.dataset import BatchedDataset
from trainlib.utils.text import camel_to_snake

from trainlib.datasets.memory import RandomXYDataset

_datasets = [
    RandomXYDataset,
]

dataset_map: dict[str, type[BatchedDataset]] = {
    camel_to_snake(_dataset.__name__): _dataset
    for _dataset in _datasets
}

trainlib/datasets/memory.py

@@ -73,13 +73,33 @@ class RecordDataset[T: NamedTuple](HomogenousDataset[int, T, T]):
from typing import Unpack

+import torch
import torch.nn.functional as F
from torch import Tensor
+from torch.utils.data import TensorDataset

from trainlib.domain import SequenceDomain
from trainlib.dataset import TupleDataset, DatasetKwargs

+class RandomXYDataset(TupleDataset[Tensor]):
+    def __init__(
+        self,
+        num_samples: int,
+        input_dim: int,
+        output_dim: int,
+        **kwargs: Unpack[DatasetKwargs],
+    ) -> None:
+        domain = SequenceDomain[tuple[Tensor, Tensor]](
+            TensorDataset(
+                torch.randn((num_samples, input_dim)),
+                torch.randn((num_samples, output_dim))
+            ),
+        )
+        super().__init__(domain, **kwargs)

class SlidingWindowDataset(TupleDataset[Tensor]):
    def __init__(
        self,
@@ -88,12 +108,26 @@ class SlidingWindowDataset(TupleDataset[Tensor]):
        offset: int = 0,
        lookahead: int = 1,
        num_windows: int = 1,
+        pad_mode: str = "constant",
+        #fill_with: str = "zero",
        **kwargs: Unpack[DatasetKwargs],
    ) -> None:
+        """
+        Parameters:
+            TODO: implement options for `fill_with`; currently just passing
+            through a `pad_mode` to the functional call, which does the job
+            fill_with: strategy to use for padding values in windows
+                - `zero`: fill with zeros
+                - `left`: use nearest window column (repeat leftmost)
+                - `mean`: fill with the window mean
+        """
        self.lookback = lookback
        self.offset = offset
        self.lookahead = lookahead
        self.num_windows = num_windows
+        self.pad_mode = pad_mode
        super().__init__(domain, **kwargs)
@@ -103,7 +137,7 @@ class SlidingWindowDataset(TupleDataset[Tensor]):
        batch_index: int,
    ) -> list[tuple[Tensor, ...]]:
        """
-        Backward pads first sequence over (lookback-1) length, and steps the
+        Backward pads window sequences over (lookback-1) length, and steps the
        remaining items forward by the lookahead.

        Batch data:
@@ -146,7 +180,7 @@ class SlidingWindowDataset(TupleDataset[Tensor]):
        exceeds the offset.

        To get windows starting with the first index at the left: we first set
-        out window size (call it L), determined by `lookback`. Then the
+        our window size (call it L), determined by `lookback`. Then the
        rightmost index we want will be `L-1`, which determines our `offset`
        setting.
@@ -173,7 +207,7 @@ class SlidingWindowDataset(TupleDataset[Tensor]):
        # for window sized `lb`, we pad with `lb-1` zeros. We then take off
        # the amount of our offset, which in the extreme cases does no
        # padding.
-        xip = F.pad(t, ((lb-1) - off, 0))
+        xip = F.pad(t, ((lb-1) - off, 0), mode=self.pad_mode)
        # extract sliding windows over the padded tensor
        # unfold(-1, lb, 1) slides over the last dim, 1 step at a time, for
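
As a standalone illustration of the pad-and-unfold pattern the comments above describe (not part of the diff), with lookback lb=3 and offset off=0:

import torch
import torch.nn.functional as F

t = torch.arange(1., 6.)             # a single feature sequence: [1, 2, 3, 4, 5]
lb, off = 3, 0                       # lookback window size and offset
xip = F.pad(t, ((lb - 1) - off, 0))  # left-pad with lb-1 zeros: [0, 0, 1, 2, 3, 4, 5]
windows = xip.unfold(-1, lb, 1)      # shape (5, 3): one lookback window per time step
# windows[0]  -> tensor([0., 0., 1.])
# windows[-1] -> tensor([3., 4., 5.])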

trainlib/estimator.py

@@ -35,6 +35,10 @@ class EstimatorKwargs(TypedDict):
    inputs: Tensor

+class SupervisedKwargs(EstimatorKwargs):
+    labels: Tensor

class Estimator[Kw: EstimatorKwargs](nn.Module):
    """
    Estimator base class.

trainlib/estimators/__init__.py Normal file

@@ -0,0 +1,15 @@
from trainlib.estimator import Estimator
from trainlib.utils.text import camel_to_snake

from trainlib.estimators.mlp import MLP
from trainlib.estimators.rnn import LSTM, ConvGRU, MultiheadLSTM

_estimators: list[type[Estimator]] = [
    MLP,
    LSTM,
    MultiheadLSTM,
    ConvGRU,
]

estimator_map: dict[str, type[Estimator]] = {
    camel_to_snake(_estimator.__name__): _estimator
    for _estimator in _estimators
}

trainlib/estimators/rnn.py

@@ -8,6 +8,7 @@ from torch import nn, Tensor
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter

+from trainlib.utils import op
from trainlib.estimator import Estimator, EstimatorKwargs
from trainlib.utils.type import OptimizerKwargs
from trainlib.utils.module import get_grad_norm
@@ -102,6 +103,7 @@ class LSTM[Kw: RNNKwargs](Estimator[Kw]):
        labels = kwargs["labels"]
        yield F.mse_loss(predictions, labels)
+        #yield F.l1_loss(predictions, labels)

    def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
        with torch.no_grad():
@@ -109,12 +111,16 @@ class LSTM[Kw: RNNKwargs](Estimator[Kw]):
            predictions = self(**kwargs)[0]
            labels = kwargs["labels"]
+            mse = F.mse_loss(predictions, labels).item()
            mae = F.l1_loss(predictions, labels).item()
+            r2 = op.r2_score(predictions, labels).item()
        return {
-            # "loss": loss,
-            "mse": loss,
+            "loss": loss,
+            "mse": mse,
            "mae": mae,
+            "r2": r2,
            "grad_norm": get_grad_norm(self)
        }
@@ -377,7 +383,8 @@ class ConvGRU[Kw: RNNKwargs](Estimator[Kw]):
        # will be (B, T, C), applies indep at each time step across channels
        # self.dense_z = nn.Linear(layer_in_dim, self.output_dim)
-        # will be (B, C, T), applies indep at each time step across channels
+        # will be (B, Co, 1), applies indep at each channel across temporal dim
+        # size time steps
        self.dense_z = TDNNLayer(
            layer_in_dim,
            self.output_dim,
@@ -438,6 +445,7 @@ class ConvGRU[Kw: RNNKwargs](Estimator[Kw]):
        predictions = predictions.squeeze(-1)
        yield F.mse_loss(predictions, labels, reduction="mean")
+        #yield F.l1_loss(predictions, labels)

    def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
        with torch.no_grad():
@@ -445,11 +453,16 @@ class ConvGRU[Kw: RNNKwargs](Estimator[Kw]):
            predictions = self(**kwargs)[0].squeeze(-1)
            labels = kwargs["labels"]
+            mse = F.mse_loss(predictions, labels).item()
            mae = F.l1_loss(predictions, labels).item()
+            r2 = op.r2_score(predictions, labels).item()
        return {
-            "mse": loss,
+            "loss": loss,
+            "mse": mse,
            "mae": mae,
+            "r2": r2,
            "grad_norm": get_grad_norm(self)
        }

View File

@@ -64,14 +64,17 @@ class Plotter[Kw: EstimatorKwargs]:
        for i, loader in enumerate(self.dataloaders):
            label = self.dataloader_labels[i]
-            actual = torch.cat([
+            actual = [
                self.kw_to_actual(batch_kwargs).detach().cpu()
                for batch_kwargs in loader
-            ])
-            output = torch.cat([
+            ]
+            actual = torch.cat([ai.reshape(*([*ai.shape]+[1])[:2]) for ai in actual])
+            output = [
                self.trainer.estimator(**batch_kwargs)[0].detach().cpu()
                for batch_kwargs in loader
-            ])
+            ]
+            output = torch.cat([oi.reshape(*([*oi.shape]+[1])[:2]) for oi in output])
            data_tuples.append((actual, output, label))

trainlib/trainer.py

@@ -32,6 +32,7 @@ from trainlib.estimator import Estimator, EstimatorKwargs
from trainlib.utils.map import nested_defaultdict
from trainlib.dataloader import EstimatorDataLoader
from trainlib.utils.module import ModelWrapper
+from trainlib.utils.session import ensure_same_device

logger: logging.Logger = logging.getLogger(__name__)
@@ -103,24 +104,39 @@ class Trainer[I, Kw: EstimatorKwargs]:
        self.reset()

-    def reset(self) -> None:
+    def reset(self, resume: bool = False) -> None:
        """
        Set initial tracking parameters for the primary training loop.
+
+        Parameters:
+            resume: if ``True``, just resets the stagnant epoch counter, with
+                the aim of continuing any existing training state under a
+                resumed ``train()`` call. This should likely only be set when
+                training is continued on the same dataset and the goal is to
+                resume convergence loss-based scoring for a fresh set of
+                epochs. If even that element of the training loop should
+                resume (which should only happen if a training loop was
+                interrupted or a max epoch limit was reached), then this
+                method shouldn't be called at all between ``train()``
+                invocations.
        """
-        self._epoch: int = 0
-        self._summary = defaultdict(lambda: defaultdict(list))
-        self._conv_loss = float("inf")
-        self._best_conv_loss = float("inf")
        self._stagnant_epochs = 0
-        self._best_model_state_dict: dict[str, Any] = {}
+        if not resume:
+            self._epoch: int = 0
+            self._summary = defaultdict(lambda: defaultdict(list))
+            self._conv_loss = float("inf")
+            self._best_conv_loss = float("inf")
+            self._best_conv_epoch = 0
+            self._best_model_state_dict: dict[str, Any] = {}

    def _train_epoch(
        self,
        loader: EstimatorDataLoader[Any, Kw],
        optimizers: tuple[Optimizer, ...],
        max_grad_norm: float | None = None,
+        progress_bar: tqdm | None = None,
    ) -> list[float]:
        """
        Train the estimator for a single epoch.
@@ -128,32 +144,40 @@ class Trainer[I, Kw: EstimatorKwargs]:
        loss_sums = []
        self.estimator.train()

-        with tqdm(loader, unit="batch") as batches:
-            for i, batch_kwargs in enumerate(batches):
+        for i, batch_kwargs in enumerate(loader):
+            batch_kwargs = ensure_same_device(batch_kwargs, self.device)
            losses = self.estimator.loss(**batch_kwargs)
            for o_idx, (loss, optimizer) in enumerate(
                zip(losses, optimizers, strict=True)
            ):
                if len(loss_sums) <= o_idx:
                    loss_sums.append(0.0)
                loss_sums[o_idx] += loss.item()
                optimizer.zero_grad()
                loss.backward()
                # clip gradients for optimizer's parameters
                if max_grad_norm is not None:
                    clip_grad_norm_(
                        self._get_optimizer_parameters(optimizer),
                        max_norm=max_grad_norm
                    )
                optimizer.step()

            # set loop loss to running average (reducing if multi-loss)
            loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
-            batches.set_postfix(loss=f"{loss_avg:8.2f}")
+            if progress_bar:
+                progress_bar.update(1)
+                progress_bar.set_postfix(
+                    epoch=self._epoch,
+                    mode="opt",
+                    data="train",
+                    loss=f"{loss_avg:8.2f}",
+                )

        # step estimator hyperparam schedules
        self.estimator.epoch_step()
@@ -163,7 +187,8 @@ class Trainer[I, Kw: EstimatorKwargs]:
    def _eval_epoch(
        self,
        loader: EstimatorDataLoader[Any, Kw],
-        label: str
+        label: str,
+        progress_bar: tqdm | None = None,
    ) -> list[float]:
        """
        Perform and record validation scores for a single epoch.
@@ -191,45 +216,53 @@ class Trainer[I, Kw: EstimatorKwargs]:
        loss_sums = []
        self.estimator.eval()

-        with tqdm(loader, unit="batch") as batches:
-            for i, batch_kwargs in enumerate(batches):
+        for i, batch_kwargs in enumerate(loader):
+            batch_kwargs = ensure_same_device(batch_kwargs, self.device)
            losses = self.estimator.loss(**batch_kwargs)

-            # one-time logging
-            if self._epoch == 0:
+            # once-per-session logging
+            if self._epoch == 0 and i == 0:
                self._writer.add_graph(
                    ModelWrapper(self.estimator), batch_kwargs
                )
            # once-per-epoch logging
            if i == 0:
                self.estimator.epoch_write(
                    self._writer,
                    step=self._epoch,
                    group=label,
                    **batch_kwargs
                )

            loss_items = []
            for o_idx, loss in enumerate(losses):
                if len(loss_sums) <= o_idx:
                    loss_sums.append(0.0)
                loss_item = loss.item()
                loss_sums[o_idx] += loss_item
                loss_items.append(loss_item)

            # set loop loss to running average (reducing if multi-loss)
            loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
-            batches.set_postfix(loss=f"{loss_avg:8.2f}")
+            if progress_bar:
+                progress_bar.update(1)
+                progress_bar.set_postfix(
+                    epoch=self._epoch,
+                    mode="eval",
+                    data=label,
+                    loss=f"{loss_avg:8.2f}",
+                )

            # log individual loss terms after each batch
            for o_idx, loss_item in enumerate(loss_items):
                self._log_event(label, f"loss_{o_idx}", loss_item)

            # log metrics for batch
            estimator_metrics = self.estimator.metrics(**batch_kwargs)
            for metric_name, metric_value in estimator_metrics.items():
                self._log_event(label, metric_name, metric_value)

        avg_losses = [loss_sum / (i+1) for loss_sum in loss_sums]
@@ -240,6 +273,7 @@ class Trainer[I, Kw: EstimatorKwargs]:
        train_loader: EstimatorDataLoader[Any, Kw],
        val_loader: EstimatorDataLoader[Any, Kw] | None = None,
        aux_loaders: list[EstimatorDataLoader[Any, Kw]] | None = None,
+        progress_bar: tqdm | None = None,
    ) -> tuple[list[float], list[float] | None, *list[float]]:
        """
        Evaluate estimator over each provided dataloader.
@@ -274,12 +308,15 @@ class Trainer[I, Kw: EstimatorKwargs]:
        somewhere given the many possible design choices here.)
        """
-        train_loss = self._eval_epoch(train_loader, "train")
-        val_loss = self._eval_epoch(val_loader, "val") if val_loader else None
+        train_loss = self._eval_epoch(train_loader, "train", progress_bar)
+
+        val_loss = None
+        if val_loader is not None:
+            val_loss = self._eval_epoch(val_loader, "val", progress_bar)

        aux_loaders = aux_loaders or []
        aux_losses = [
-            self._eval_epoch(aux_loader, f"aux{i}")
+            self._eval_epoch(aux_loader, f"aux{i}", progress_bar)
            for i, aux_loader in enumerate(aux_loaders)
        ]
@@ -433,27 +470,36 @@ class Trainer[I, Kw: EstimatorKwargs]:
        self._session_name = session_name or str(int(time.time()))
        tblog_path = Path(self.tblog_dir, self._session_name)
        self._writer = summary_writer or SummaryWriter(f"{tblog_path}")
+        progress_bar = tqdm(train_loader, unit="batch")

        # evaluate model on dataloaders once before training starts
-        self._eval_loaders(train_loader, val_loader, aux_loaders)
+        train_loss, val_loss, *_ = self._eval_loaders(
+            train_loader, val_loader, aux_loaders, progress_bar
+        )
+        conv_loss = val_loss if val_loss else train_loss
+        self._conv_loss = sum(conv_loss) / len(conv_loss)

        optimizers = self.estimator.optimizers(lr=lr, eps=eps)
-        while self._epoch < max_epochs and not self._converged(
-            self._epoch, stop_after_epochs
-        ):
+        while self._epoch < max_epochs and not self._converged(stop_after_epochs):
            self._epoch += 1
-            train_frac = f"{self._epoch}/{max_epochs}"
-            stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}"
-            print(f"Training epoch {train_frac}...")
-            print(f"Stagnant epochs {stag_frac}...")
+            #train_frac = f"{self._epoch}/{max_epochs}"
+            #stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}"
+            #print(f"Training epoch {train_frac}...")
+            #print(f"Stagnant epochs {stag_frac}...")

            epoch_start_time = time.time()
-            self._train_epoch(train_loader, optimizers, max_grad_norm)
+            self._train_epoch(
+                train_loader,
+                optimizers,
+                max_grad_norm,
+                progress_bar=progress_bar
+            )
            epoch_end_time = time.time() - epoch_start_time
            self._log_event("train", "epoch_duration", epoch_end_time)

-            train_loss, val_loss, _ = self._eval_loaders(
-                train_loader, val_loader, aux_loaders
+            train_loss, val_loss, *_ = self._eval_loaders(
+                train_loader, val_loader, aux_loaders, progress_bar
            )

            # determine loss to use for measuring convergence
            conv_loss = val_loss if val_loss else train_loss
@@ -466,12 +512,43 @@ class Trainer[I, Kw: EstimatorKwargs]:
        return self.estimator

-    def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
+    def _converged(self, stop_after_epochs: int) -> bool:
+        """
+        Check if model has converged.
+
+        This method looks at the current "convergence loss" (validation-based
+        if a val set is provided to ``train()``, otherwise the training loss
+        is used), checking if it's the best yet recorded and incrementing the
+        stagnancy count if not. Convergence is asserted only if the number of
+        stagnant epochs exceeds ``stop_after_epochs``.
+
+        .. admonition:: Evaluation order
+
+            Convergence losses are recorded before the first training update,
+            so initial model states are appropriately benchmarked by the time
+            ``_converged()`` is invoked.
+
+            If resuming training on the same dataset, one might expect only to
+            reset the stagnant epoch counter: you'll resume from the last
+            epoch, estimator state, and best seen loss, while being allowed
+            ``stop_after_epochs`` more chances for better validation.
+
+            If picking up training on a new dataset, even in a
+            training+validation setting, resetting the best seen loss and best
+            model state is needed: you can't reliably compare the existing
+            stats under new data. It's somewhat ambiguous whether ``epoch``
+            absolutely must be reset; you could continue logging metrics under
+            the same named session. But best practices would suggest
+            restarting the epoch count and having events logged under a new
+            session heading when data change.
+        """
        converged = False
-        if epoch == 0 or self._conv_loss < self._best_val_loss:
-            self._best_val_loss = self._conv_loss
+        if self._conv_loss < self._best_conv_loss:
            self._stagnant_epochs = 0
+            self._best_conv_loss = self._conv_loss
+            self._best_conv_epoch = self._epoch
            self._best_model_state_dict = deepcopy(self.estimator.state_dict())
        else:
            self._stagnant_epochs += 1
@@ -491,6 +568,7 @@ class Trainer[I, Kw: EstimatorKwargs]:
        print(f"==== Epoch [{self._epoch}] summary ====")
        for (group, name), epoch_map in self._summary.items():
            for epoch, values in epoch_map.items():
+                # compute average over batch items recorded for the epoch
                mean = torch.tensor(values).mean().item()
                self._writer.add_scalar(f"{group}-{name}", mean, epoch)
                if epoch == self._epoch:
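
To make the new reset(resume=...) / _converged flow concrete, a small end-to-end sketch (outside this changeset; component keys come from the registries added in this diff, keyword values are illustrative, and the split/loader construction mirrors cli/train.py):

from trainlib.trainer import Trainer
from trainlib.datasets import dataset_map
from trainlib.estimators import estimator_map
from trainlib.dataloaders import dataloader_map

estimator = estimator_map["mlp"](input_dim=4, output_dim=2)
dataset = dataset_map["random_xy_dataset"](num_samples=1000, input_dim=4, output_dim=2)
train_ds, val_ds = dataset.split(fracs=[0.7, 0.3])
train_loader = dataloader_map["supervised_data_loader"](train_ds, batch_size=16)
val_loader = dataloader_map["supervised_data_loader"](val_ds, batch_size=16)

trainer = Trainer(estimator)
trainer.train(train_loader=train_loader, val_loader=val_loader,
              max_epochs=20, stop_after_epochs=5)

# Same data, fresh stagnancy budget: keep the best loss/model, clear only the counter.
trainer.reset(resume=True)
trainer.train(train_loader=train_loader, val_loader=val_loader,
              max_epochs=20, stop_after_epochs=5)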

9
trainlib/utils/op.py Normal file

@@ -0,0 +1,9 @@
from torch import Tensor

def r2_score(y: Tensor, y_hat: Tensor) -> Tensor:
    ss_res = ((y - y_hat)**2).sum()
    ss_tot = ((y - y.mean())**2).sum()
    r2 = 1 - ss_res / ss_tot
    return r2
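
A quick sanity check (not part of the diff); note that the estimators above call this as op.r2_score(predictions, labels):

import torch
from trainlib.utils import op

labels = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(op.r2_score(labels, labels))        # perfect predictions: ss_res = 0, so R^2 = 1
print(op.r2_score(labels + 0.5, labels))  # a constant offset lowers the score (here 0.8)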

trainlib/utils/session.py

@@ -3,6 +3,7 @@ import random
import numpy as np
import torch
from torch import Tensor
+from torch.utils import _pytree as pytree

def seed_all_backends(seed: int | Tensor | None = None) -> None:
@@ -19,3 +20,9 @@ def seed_all_backends(seed: int | Tensor | None = None) -> None:
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
+
+def ensure_same_device[T](tree: T, device: str) -> T:
+    return pytree.tree_map(
+        lambda x: x.to(device) if isinstance(x, torch.Tensor) else x,
+        tree,
+    )
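
Illustratively (not part of the diff), this is what the trainer now does to each batch before computing losses: every tensor leaf in the kwargs mapping is moved to the estimator's device, while non-tensor leaves pass through untouched:

import torch
from trainlib.utils.session import ensure_same_device

batch_kwargs = {"inputs": torch.randn(16, 4), "labels": torch.randn(16, 2)}
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_kwargs = ensure_same_device(batch_kwargs, device)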

trainlib/utils/text.py

@@ -1,8 +1,14 @@
+import re
from typing import Any

from colorama import Style

+camel2snake_regex: re.Pattern[str] = re.compile(
+    r"(?<!^)(?=[A-Z][a-z])|(?<=[a-z])(?=[A-Z])"
+)
+
+def camel_to_snake(text: str) -> str:
+    return camel2snake_regex.sub("_", text).lower()

def color_text(text: str, *colorama_args: Any) -> str:
    return f"{''.join(colorama_args)}{text}{Style.RESET_ALL}"

View File

@@ -61,8 +61,8 @@ class SubplotsKwargs(TypedDict, total=False):
    squeeze: bool
    width_ratios: Sequence[float]
    height_ratios: Sequence[float]
-    subplot_kw: dict[str, ...]
-    gridspec_kw: dict[str, ...]
+    subplot_kw: dict[str, Any]
+    gridspec_kw: dict[str, Any]
    figsize: tuple[float, float]
    dpi: float
    layout: str

2
uv.lock generated

@@ -1659,7 +1659,7 @@ wheels = [
[[package]]
name = "trainlib"
-version = "0.1.2"
+version = "0.3.1"
source = { editable = "." }
dependencies = [
    { name = "colorama" },