1 Commits
0.2.0 ... 0.3.0

Author SHA1 Message Date
ba0c804d5e implement CLI trainer utility, update progress bar logging 2026-03-25 02:25:28 -07:00
16 changed files with 430 additions and 70 deletions

29
example/example.json Normal file
View File

@@ -0,0 +1,29 @@
{
"estimator_name": "mlp",
"dataset_name": "random_xy_dataset",
"dataloader_name": "supervised_data_loader",
"estimator_kwargs": {
"input_dim": 4,
"output_dim": 2
},
"dataset_kwargs": {
"num_samples": 100000,
"preload": true,
"input_dim": 4,
"output_dim": 2
},
"dataset_split_fracs": {
"train": 0.4,
"val": 0.3,
"aux": [0.3]
},
"dataloader_kwargs": {
"batch_size": 16
},
"train_kwargs": {
"summarize_every": 20,
"max_epochs": 100,
"stop_after_epochs": 100
},
"load_only": false
}

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "trainlib" name = "trainlib"
version = "0.2.0" version = "0.3.0"
description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training." description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training."
requires-python = ">=3.13" requires-python = ">=3.13"
authors = [ authors = [

18
trainlib/__main__.py Normal file
View File

@@ -0,0 +1,18 @@
import logging
from trainlib.cli import create_parser
def main() -> None:
parser = create_parser()
args = parser.parse_args()
# skim off log level to handle higher-level option
if hasattr(args, "log_level") and args.log_level is not None:
logging.basicConfig(level=args.log_level)
args.func(args) if "func" in args else parser.print_help()
if __name__ == "__main__":
main()

26
trainlib/cli/__init__.py Normal file
View File

@@ -0,0 +1,26 @@
import logging
import argparse
from trainlib.cli import train
logger: logging.Logger = logging.getLogger(__name__)
def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="trainlib cli",
# formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--log-level",
type=int,
metavar="int",
choices=[10, 20, 30, 40, 50],
help="Log level: 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL",
)
subparsers = parser.add_subparsers(help="subcommand help")
train.register_parser(subparsers)
return parser

164
trainlib/cli/train.py Normal file
View File

@@ -0,0 +1,164 @@
import gc
import json
import argparse
from typing import Any
from argparse import _SubParsersAction
import torch
from trainlib.trainer import Trainer
from trainlib.datasets import dataset_map
from trainlib.estimator import Estimator
from trainlib.estimators import estimator_map
from trainlib.dataloaders import dataloader_map
def prepare_run() -> None:
# prepare cuda memory
memory_allocated = torch.cuda.memory_allocated() / 1024**3 # GB
print(f"CUDA allocated: {memory_allocated}GB")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
def run(
estimator_name: str,
dataset_name: str,
dataloader_name: str,
estimator_kwargs: dict[str, Any] | None = None,
dataset_kwargs: dict[str, Any] | None = None,
dataset_split_fracs: dict[str, Any] | None = None,
dataset_split_kwargs: dict[str, Any] | None = None,
dataloader_kwargs: dict[str, Any] | None = None,
trainer_kwargs: dict[str, Any] | None = None,
train_kwargs: dict[str, Any] | None = None,
load_only: bool = False,
) -> Trainer | Estimator:
try:
estimator_cls = estimator_map[estimator_name]
except KeyError as err:
raise ValueError(
f"Invalid estimator name '{estimator_name}',"
f"must be one of {estimator_map.keys()}"
) from err
try:
dataset_cls = dataset_map[dataset_name]
except KeyError as err:
raise ValueError(
f"Invalid dataset name '{dataset_name}',"
f"must be one of {dataset_map.keys()}"
) from err
try:
dataloader_cls = dataloader_map[dataloader_name]
except KeyError as err:
raise ValueError(
f"Invalid dataloader name '{dataloader_name}',"
f"must be one of {dataloader_map.keys()}"
) from err
estimator_kwargs = estimator_kwargs or {}
dataset_kwargs = dataset_kwargs or {}
dataset_split_fracs = dataset_split_fracs or {}
dataset_split_kwargs = dataset_split_kwargs or {}
dataloader_kwargs = dataloader_kwargs or {}
trainer_kwargs = trainer_kwargs or {}
train_kwargs = train_kwargs or {}
default_estimator_kwargs = {}
default_dataset_kwargs = {}
default_dataset_split_kwargs = {}
default_dataset_split_fracs = {"train": 1.0, "val": 0.0, "aux": []}
default_dataloader_kwargs = {}
default_trainer_kwargs = {}
default_train_kwargs = {}
estimator_kwargs = {**default_estimator_kwargs, **estimator_kwargs}
dataset_kwargs = {**default_dataset_kwargs, **dataset_kwargs}
dataset_split_kwargs = {**default_dataset_split_kwargs, **dataset_split_kwargs}
dataset_split_fracs = {**default_dataset_split_fracs, **dataset_split_fracs}
dataloader_kwargs = {**default_dataloader_kwargs, **dataloader_kwargs}
trainer_kwargs = {**default_trainer_kwargs, **trainer_kwargs}
train_kwargs = {**default_train_kwargs, **train_kwargs}
estimator = estimator_cls(**estimator_kwargs)
dataset = dataset_cls(**dataset_kwargs)
train_dataset, val_dataset, *aux_datasets = dataset.split(
fracs=[
dataset_split_fracs["train"],
dataset_split_fracs["val"],
*dataset_split_fracs["aux"]
],
**dataset_split_kwargs
)
train_loader = dataloader_cls(train_dataset, **dataloader_kwargs)
val_loader = dataloader_cls(val_dataset, **dataloader_kwargs)
aux_loaders = [
dataloader_cls(aux_dataset, **dataloader_kwargs)
for aux_dataset in aux_datasets
]
trainer = Trainer(
estimator,
**trainer_kwargs,
)
if load_only:
return trainer
return trainer.train(
train_loader=train_loader,
val_loader=val_loader,
aux_loaders=aux_loaders,
**train_kwargs,
)
def run_from_json(
parameters_json: str | None = None,
parameters_file: str | None = None,
) -> Trainer | Estimator:
if not (parameters_json or parameters_file):
raise ValueError("parameter json or file required")
parameters: dict[str, Any]
if parameters_json:
try:
parameters = json.loads(parameters_json)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON format: {e}") from e
except Exception as e:
raise ValueError(f"Error loading JSON parameters: {e}") from e
elif parameters_file:
try:
with open(parameters_file, encoding="utf-8") as f:
parameters = json.load(f)
except FileNotFoundError as e:
raise ValueError(f"JSON file not found: {parameters_file}") from e
except Exception as e:
raise ValueError(f"Error loading JSON parameters: {e}") from e
return run(**parameters)
def handle_train(args: argparse.Namespace) -> None:
run_from_json(args.parameters_json, args.parameters_file)
def register_parser(subparsers: _SubParsersAction) -> None:
parser = subparsers.add_parser("train", help="run training loop")
parser.add_argument(
"--parameters-json",
type=str,
help="Raw JSON string with train parameters",
)
parser.add_argument(
"--parameters-file",
type=str,
help="Path to JSON file with train parameters",
)
parser.set_defaults(func=handle_train)

View File

@@ -0,0 +1,12 @@
from trainlib.dataloader import EstimatorDataLoader
from trainlib.utils.text import camel_to_snake
from trainlib.dataloaders.memory import SupervisedDataLoader
_dataloaders = [
SupervisedDataLoader,
]
dataloader_map: dict[str, type[EstimatorDataLoader]] = {
camel_to_snake(_dataloader.__name__): _dataloader
for _dataloader in _dataloaders
}

View File

@@ -0,0 +1,17 @@
from torch import Tensor
from trainlib.estimator import SupervisedKwargs
from trainlib.dataloader import EstimatorDataLoader
class SupervisedDataLoader(
EstimatorDataLoader[tuple[Tensor, Tensor], SupervisedKwargs]
):
def batch_to_est_kwargs(
self,
batch_data: tuple[Tensor, Tensor]
) -> SupervisedKwargs:
return SupervisedKwargs(
inputs=batch_data[0],
labels=batch_data[1],
)

View File

@@ -0,0 +1,12 @@
from trainlib.dataset import BatchedDataset
from trainlib.utils.text import camel_to_snake
from trainlib.datasets.memory import RandomXYDataset
_datasets = [
RandomXYDataset,
]
dataset_map: dict[str, type[BatchedDataset]] = {
camel_to_snake(_dataset.__name__): _dataset
for _dataset in _datasets
}

View File

@@ -73,13 +73,33 @@ class RecordDataset[T: NamedTuple](HomogenousDataset[int, T, T]):
from typing import Unpack from typing import Unpack
import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from torch.utils.data import TensorDataset
from trainlib.domain import SequenceDomain from trainlib.domain import SequenceDomain
from trainlib.dataset import TupleDataset, DatasetKwargs from trainlib.dataset import TupleDataset, DatasetKwargs
class RandomXYDataset(TupleDataset[Tensor]):
def __init__(
self,
num_samples: int,
input_dim: int,
output_dim: int,
**kwargs: Unpack[DatasetKwargs],
) -> None:
domain = SequenceDomain[tuple[Tensor, Tensor]](
TensorDataset(
torch.randn((num_samples, input_dim)),
torch.randn((num_samples, output_dim))
),
)
super().__init__(domain, **kwargs)
class SlidingWindowDataset(TupleDataset[Tensor]): class SlidingWindowDataset(TupleDataset[Tensor]):
def __init__( def __init__(
self, self,

View File

@@ -35,6 +35,10 @@ class EstimatorKwargs(TypedDict):
inputs: Tensor inputs: Tensor
class SupervisedKwargs(EstimatorKwargs):
labels: Tensor
class Estimator[Kw: EstimatorKwargs](nn.Module): class Estimator[Kw: EstimatorKwargs](nn.Module):
""" """
Estimator base class. Estimator base class.

View File

@@ -0,0 +1,15 @@
from trainlib.estimator import Estimator
from trainlib.utils.text import camel_to_snake
from trainlib.estimators.mlp import MLP
from trainlib.estimators.rnn import LSTM, ConvGRU, MultiheadLSTM
_estimators: list[type[Estimator]] = [
MLP,
LSTM,
MultiheadLSTM,
ConvGRU,
]
estimator_map: dict[str, type[Estimator]] = {
camel_to_snake(_estimator.__name__): _estimator
for _estimator in _estimators
}

View File

@@ -32,6 +32,7 @@ from trainlib.estimator import Estimator, EstimatorKwargs
from trainlib.utils.map import nested_defaultdict from trainlib.utils.map import nested_defaultdict
from trainlib.dataloader import EstimatorDataLoader from trainlib.dataloader import EstimatorDataLoader
from trainlib.utils.module import ModelWrapper from trainlib.utils.module import ModelWrapper
from trainlib.utils.session import ensure_same_device
logger: logging.Logger = logging.getLogger(__name__) logger: logging.Logger = logging.getLogger(__name__)
@@ -121,6 +122,7 @@ class Trainer[I, Kw: EstimatorKwargs]:
loader: EstimatorDataLoader[Any, Kw], loader: EstimatorDataLoader[Any, Kw],
optimizers: tuple[Optimizer, ...], optimizers: tuple[Optimizer, ...],
max_grad_norm: float | None = None, max_grad_norm: float | None = None,
progress_bar: tqdm | None = None,
) -> list[float]: ) -> list[float]:
""" """
Train the estimator for a single epoch. Train the estimator for a single epoch.
@@ -128,32 +130,40 @@ class Trainer[I, Kw: EstimatorKwargs]:
loss_sums = [] loss_sums = []
self.estimator.train() self.estimator.train()
with tqdm(loader, unit="batch") as batches: for i, batch_kwargs in enumerate(loader):
for i, batch_kwargs in enumerate(batches): batch_kwargs = ensure_same_device(batch_kwargs, self.device)
losses = self.estimator.loss(**batch_kwargs) losses = self.estimator.loss(**batch_kwargs)
for o_idx, (loss, optimizer) in enumerate( for o_idx, (loss, optimizer) in enumerate(
zip(losses, optimizers, strict=True) zip(losses, optimizers, strict=True)
): ):
if len(loss_sums) <= o_idx: if len(loss_sums) <= o_idx:
loss_sums.append(0.0) loss_sums.append(0.0)
loss_sums[o_idx] += loss.item() loss_sums[o_idx] += loss.item()
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
# clip gradients for optimizer's parameters # clip gradients for optimizer's parameters
if max_grad_norm is not None: if max_grad_norm is not None:
clip_grad_norm_( clip_grad_norm_(
self._get_optimizer_parameters(optimizer), self._get_optimizer_parameters(optimizer),
max_norm=max_grad_norm max_norm=max_grad_norm
) )
optimizer.step() optimizer.step()
# set loop loss to running average (reducing if multi-loss) # set loop loss to running average (reducing if multi-loss)
loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1)) loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
batches.set_postfix(loss=f"{loss_avg:8.2f}")
if progress_bar:
progress_bar.update(1)
progress_bar.set_postfix(
epoch=self._epoch,
mode="opt",
data="train",
loss=f"{loss_avg:8.2f}",
)
# step estimator hyperparam schedules # step estimator hyperparam schedules
self.estimator.epoch_step() self.estimator.epoch_step()
@@ -163,7 +173,8 @@ class Trainer[I, Kw: EstimatorKwargs]:
def _eval_epoch( def _eval_epoch(
self, self,
loader: EstimatorDataLoader[Any, Kw], loader: EstimatorDataLoader[Any, Kw],
label: str label: str,
progress_bar: tqdm | None = None,
) -> list[float]: ) -> list[float]:
""" """
Perform and record validation scores for a single epoch. Perform and record validation scores for a single epoch.
@@ -191,45 +202,53 @@ class Trainer[I, Kw: EstimatorKwargs]:
loss_sums = [] loss_sums = []
self.estimator.eval() self.estimator.eval()
with tqdm(loader, unit="batch") as batches: for i, batch_kwargs in enumerate(loader):
for i, batch_kwargs in enumerate(batches): batch_kwargs = ensure_same_device(batch_kwargs, self.device)
losses = self.estimator.loss(**batch_kwargs) losses = self.estimator.loss(**batch_kwargs)
# one-time logging
if self._epoch == 0:
self._writer.add_graph(
ModelWrapper(self.estimator), batch_kwargs
)
# once-per-epoch logging
if i == 0:
self.estimator.epoch_write(
self._writer,
step=self._epoch,
group=label,
**batch_kwargs
)
loss_items = [] # once-per-session logging
for o_idx, loss in enumerate(losses): if self._epoch == 0 and i == 0:
if len(loss_sums) <= o_idx: self._writer.add_graph(
loss_sums.append(0.0) ModelWrapper(self.estimator), batch_kwargs
)
# once-per-epoch logging
if i == 0:
self.estimator.epoch_write(
self._writer,
step=self._epoch,
group=label,
**batch_kwargs
)
loss_item = loss.item() loss_items = []
loss_sums[o_idx] += loss_item for o_idx, loss in enumerate(losses):
loss_items.append(loss_item) if len(loss_sums) <= o_idx:
loss_sums.append(0.0)
# set loop loss to running average (reducing if multi-loss) loss_item = loss.item()
loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1)) loss_sums[o_idx] += loss_item
batches.set_postfix(loss=f"{loss_avg:8.2f}") loss_items.append(loss_item)
# log individual loss terms after each batch # set loop loss to running average (reducing if multi-loss)
for o_idx, loss_item in enumerate(loss_items): loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
self._log_event(label, f"loss_{o_idx}", loss_item)
# log metrics for batch if progress_bar:
estimator_metrics = self.estimator.metrics(**batch_kwargs) progress_bar.update(1)
for metric_name, metric_value in estimator_metrics.items(): progress_bar.set_postfix(
self._log_event(label, metric_name, metric_value) epoch=self._epoch,
mode="eval",
data=label,
loss=f"{loss_avg:8.2f}",
)
# log individual loss terms after each batch
for o_idx, loss_item in enumerate(loss_items):
self._log_event(label, f"loss_{o_idx}", loss_item)
# log metrics for batch
estimator_metrics = self.estimator.metrics(**batch_kwargs)
for metric_name, metric_value in estimator_metrics.items():
self._log_event(label, metric_name, metric_value)
avg_losses = [loss_sum / (i+1) for loss_sum in loss_sums] avg_losses = [loss_sum / (i+1) for loss_sum in loss_sums]
@@ -240,6 +259,7 @@ class Trainer[I, Kw: EstimatorKwargs]:
train_loader: EstimatorDataLoader[Any, Kw], train_loader: EstimatorDataLoader[Any, Kw],
val_loader: EstimatorDataLoader[Any, Kw] | None = None, val_loader: EstimatorDataLoader[Any, Kw] | None = None,
aux_loaders: list[EstimatorDataLoader[Any, Kw]] | None = None, aux_loaders: list[EstimatorDataLoader[Any, Kw]] | None = None,
progress_bar: tqdm | None = None,
) -> tuple[list[float], list[float] | None, *list[float]]: ) -> tuple[list[float], list[float] | None, *list[float]]:
""" """
Evaluate estimator over each provided dataloader. Evaluate estimator over each provided dataloader.
@@ -274,12 +294,15 @@ class Trainer[I, Kw: EstimatorKwargs]:
somewhere given the many possible design choices here.) somewhere given the many possible design choices here.)
""" """
train_loss = self._eval_epoch(train_loader, "train") train_loss = self._eval_epoch(train_loader, "train", progress_bar)
val_loss = self._eval_epoch(val_loader, "val") if val_loader else None
val_loss = None
if val_loader is not None:
val_loss = self._eval_epoch(val_loader, "val", progress_bar)
aux_loaders = aux_loaders or [] aux_loaders = aux_loaders or []
aux_losses = [ aux_losses = [
self._eval_epoch(aux_loader, f"aux{i}") self._eval_epoch(aux_loader, f"aux{i}", progress_bar)
for i, aux_loader in enumerate(aux_loaders) for i, aux_loader in enumerate(aux_loaders)
] ]
@@ -433,9 +456,10 @@ class Trainer[I, Kw: EstimatorKwargs]:
self._session_name = session_name or str(int(time.time())) self._session_name = session_name or str(int(time.time()))
tblog_path = Path(self.tblog_dir, self._session_name) tblog_path = Path(self.tblog_dir, self._session_name)
self._writer = summary_writer or SummaryWriter(f"{tblog_path}") self._writer = summary_writer or SummaryWriter(f"{tblog_path}")
progress_bar = tqdm(train_loader, unit="batch")
# evaluate model on dataloaders once before training starts # evaluate model on dataloaders once before training starts
self._eval_loaders(train_loader, val_loader, aux_loaders) self._eval_loaders(train_loader, val_loader, aux_loaders, progress_bar)
optimizers = self.estimator.optimizers(lr=lr, eps=eps) optimizers = self.estimator.optimizers(lr=lr, eps=eps)
while self._epoch < max_epochs and not self._converged( while self._epoch < max_epochs and not self._converged(
@@ -444,16 +468,21 @@ class Trainer[I, Kw: EstimatorKwargs]:
self._epoch += 1 self._epoch += 1
train_frac = f"{self._epoch}/{max_epochs}" train_frac = f"{self._epoch}/{max_epochs}"
stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}" stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}"
print(f"Training epoch {train_frac}...") #print(f"Training epoch {train_frac}...")
print(f"Stagnant epochs {stag_frac}...") #print(f"Stagnant epochs {stag_frac}...")
epoch_start_time = time.time() epoch_start_time = time.time()
self._train_epoch(train_loader, optimizers, max_grad_norm) self._train_epoch(
train_loader,
optimizers,
max_grad_norm,
progress_bar=progress_bar
)
epoch_end_time = time.time() - epoch_start_time epoch_end_time = time.time() - epoch_start_time
self._log_event("train", "epoch_duration", epoch_end_time) self._log_event("train", "epoch_duration", epoch_end_time)
train_loss, val_loss, _ = self._eval_loaders( train_loss, val_loss, *_ = self._eval_loaders(
train_loader, val_loader, aux_loaders train_loader, val_loader, aux_loaders, progress_bar
) )
# determine loss to use for measuring convergence # determine loss to use for measuring convergence
conv_loss = val_loss if val_loss else train_loss conv_loss = val_loss if val_loss else train_loss
@@ -491,6 +520,7 @@ class Trainer[I, Kw: EstimatorKwargs]:
print(f"==== Epoch [{self._epoch}] summary ====") print(f"==== Epoch [{self._epoch}] summary ====")
for (group, name), epoch_map in self._summary.items(): for (group, name), epoch_map in self._summary.items():
for epoch, values in epoch_map.items(): for epoch, values in epoch_map.items():
# compute average over batch items recorded for the epoch
mean = torch.tensor(values).mean().item() mean = torch.tensor(values).mean().item()
self._writer.add_scalar(f"{group}-{name}", mean, epoch) self._writer.add_scalar(f"{group}-{name}", mean, epoch)
if epoch == self._epoch: if epoch == self._epoch:

View File

@@ -3,6 +3,7 @@ import random
import numpy as np import numpy as np
import torch import torch
from torch import Tensor from torch import Tensor
from torch.utils import _pytree as pytree
def seed_all_backends(seed: int | Tensor | None = None) -> None: def seed_all_backends(seed: int | Tensor | None = None) -> None:
@@ -19,3 +20,9 @@ def seed_all_backends(seed: int | Tensor | None = None) -> None:
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
def ensure_same_device[T](tree: T, device: str) -> T:
return pytree.tree_map(
lambda x: x.to(device) if isinstance(x, torch.Tensor) else x,
tree,
)

View File

@@ -1,8 +1,14 @@
import re
from typing import Any from typing import Any
from colorama import Style from colorama import Style
camel2snake_regex: re.Pattern[str] = re.compile(
r"(?<!^)(?=[A-Z][a-z])|(?<=[a-z])(?=[A-Z])"
)
def camel_to_snake(text: str) -> str:
return camel2snake_regex.sub("_", text).lower()
def color_text(text: str, *colorama_args: Any) -> str: def color_text(text: str, *colorama_args: Any) -> str:
return f"{''.join(colorama_args)}{text}{Style.RESET_ALL}" return f"{''.join(colorama_args)}{text}{Style.RESET_ALL}"

View File

@@ -61,8 +61,8 @@ class SubplotsKwargs(TypedDict, total=False):
squeeze: bool squeeze: bool
width_ratios: Sequence[float] width_ratios: Sequence[float]
height_ratios: Sequence[float] height_ratios: Sequence[float]
subplot_kw: dict[str, ...] subplot_kw: dict[str, Any]
gridspec_kw: dict[str, ...] gridspec_kw: dict[str, Any]
figsize: tuple[float, float] figsize: tuple[float, float]
dpi: float dpi: float
layout: str layout: str

2
uv.lock generated
View File

@@ -1659,7 +1659,7 @@ wheels = [
[[package]] [[package]]
name = "trainlib" name = "trainlib"
version = "0.1.2" version = "0.2.1"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "colorama" }, { name = "colorama" },