Compare commits

...

3 Commits

22 changed files with 754 additions and 339 deletions

11
TODO.md Normal file
View File

@@ -0,0 +1,11 @@
# Long-term
- Implement a dataloader in-house, with a clear, lightweight mechanism for
collection-of-structures to structure-of-collections. For multi-proc handling
(happens in torch's dataloader, as well as the BatchedDataset for two
different purposes), we should rely on (a hopefully more stable) `execlib`.
- `Domains` may be externalized (`co3` or `convlib`)
- Up next: CLI, full JSON-ification of model selection + train.
- Consider a "multi-train" alternative (or arg support in `train()`) for
training many "rollouts" from the same base estimator (basically forks under
different seeds). For architecture benchmarking above all, seeing average
training behavior. Consider corresponding `Plotter` methods (error bars)

29
example/example.json Normal file
View File

@@ -0,0 +1,29 @@
{
"estimator_name": "mlp",
"dataset_name": "random_xy_dataset",
"dataloader_name": "supervised_data_loader",
"estimator_kwargs": {
"input_dim": 4,
"output_dim": 2
},
"dataset_kwargs": {
"num_samples": 100000,
"preload": true,
"input_dim": 4,
"output_dim": 2
},
"dataset_split_fracs": {
"train": 0.4,
"val": 0.3,
"aux": [0.3]
},
"dataloader_kwargs": {
"batch_size": 16
},
"train_kwargs": {
"summarize_every": 20,
"max_epochs": 100,
"stop_after_epochs": 100
},
"load_only": false
}

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "trainlib" name = "trainlib"
version = "0.1.2" version = "0.3.0"
description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training." description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training."
requires-python = ">=3.13" requires-python = ">=3.13"
authors = [ authors = [

18
trainlib/__main__.py Normal file
View File

@@ -0,0 +1,18 @@
import logging
from trainlib.cli import create_parser
def main() -> None:
    """Entry point for the trainlib CLI.

    Parses arguments, configures root logging when ``--log-level`` was
    provided, then dispatches to the selected subcommand's handler. Prints
    help when no subcommand was chosen.
    """
    parser = create_parser()
    args = parser.parse_args()

    # skim off log level to handle higher-level option
    if getattr(args, "log_level", None) is not None:
        logging.basicConfig(level=args.log_level)

    # subcommands register a "func" default; its absence means no subcommand
    if "func" in args:
        args.func(args)
    else:
        parser.print_help()


if __name__ == "__main__":
    main()

26
trainlib/cli/__init__.py Normal file
View File

@@ -0,0 +1,26 @@
import logging
import argparse
from trainlib.cli import train
logger: logging.Logger = logging.getLogger(__name__)
def create_parser() -> argparse.ArgumentParser:
    """Assemble the root CLI parser: global options plus subcommands."""
    root = argparse.ArgumentParser(
        description="trainlib cli",
        # formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    # accepts only the standard logging levels, DEBUG(10) .. CRITICAL(50)
    root.add_argument(
        "--log-level",
        type=int,
        metavar="int",
        choices=list(range(10, 51, 10)),
        help="Log level: 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL",
    )
    commands = root.add_subparsers(help="subcommand help")
    train.register_parser(commands)
    return root

164
trainlib/cli/train.py Normal file
View File

@@ -0,0 +1,164 @@
import gc
import json
import argparse
from typing import Any
from argparse import _SubParsersAction
import torch
from trainlib.trainer import Trainer
from trainlib.datasets import dataset_map
from trainlib.estimator import Estimator
from trainlib.estimators import estimator_map
from trainlib.dataloaders import dataloader_map
def prepare_run() -> None:
    """Report and release cached GPU memory before a run.

    No-op on CPU-only hosts. Previously this queried
    ``torch.cuda.memory_allocated()`` before checking availability, which
    can fail (or mislead) on machines without CUDA; all CUDA calls are now
    guarded. Garbage collection runs first so dropped tensor references can
    actually be released when the cache is emptied.
    """
    if torch.cuda.is_available():
        # report before clearing so leftovers from prior runs are visible
        memory_allocated = torch.cuda.memory_allocated() / 1024**3  # GB
        print(f"CUDA allocated: {memory_allocated}GB")
        gc.collect()
        torch.cuda.empty_cache()
def run(
estimator_name: str,
dataset_name: str,
dataloader_name: str,
estimator_kwargs: dict[str, Any] | None = None,
dataset_kwargs: dict[str, Any] | None = None,
dataset_split_fracs: dict[str, Any] | None = None,
dataset_split_kwargs: dict[str, Any] | None = None,
dataloader_kwargs: dict[str, Any] | None = None,
trainer_kwargs: dict[str, Any] | None = None,
train_kwargs: dict[str, Any] | None = None,
load_only: bool = False,
) -> Trainer | Estimator:
try:
estimator_cls = estimator_map[estimator_name]
except KeyError as err:
raise ValueError(
f"Invalid estimator name '{estimator_name}',"
f"must be one of {estimator_map.keys()}"
) from err
try:
dataset_cls = dataset_map[dataset_name]
except KeyError as err:
raise ValueError(
f"Invalid dataset name '{dataset_name}',"
f"must be one of {dataset_map.keys()}"
) from err
try:
dataloader_cls = dataloader_map[dataloader_name]
except KeyError as err:
raise ValueError(
f"Invalid dataloader name '{dataloader_name}',"
f"must be one of {dataloader_map.keys()}"
) from err
estimator_kwargs = estimator_kwargs or {}
dataset_kwargs = dataset_kwargs or {}
dataset_split_fracs = dataset_split_fracs or {}
dataset_split_kwargs = dataset_split_kwargs or {}
dataloader_kwargs = dataloader_kwargs or {}
trainer_kwargs = trainer_kwargs or {}
train_kwargs = train_kwargs or {}
default_estimator_kwargs = {}
default_dataset_kwargs = {}
default_dataset_split_kwargs = {}
default_dataset_split_fracs = {"train": 1.0, "val": 0.0, "aux": []}
default_dataloader_kwargs = {}
default_trainer_kwargs = {}
default_train_kwargs = {}
estimator_kwargs = {**default_estimator_kwargs, **estimator_kwargs}
dataset_kwargs = {**default_dataset_kwargs, **dataset_kwargs}
dataset_split_kwargs = {**default_dataset_split_kwargs, **dataset_split_kwargs}
dataset_split_fracs = {**default_dataset_split_fracs, **dataset_split_fracs}
dataloader_kwargs = {**default_dataloader_kwargs, **dataloader_kwargs}
trainer_kwargs = {**default_trainer_kwargs, **trainer_kwargs}
train_kwargs = {**default_train_kwargs, **train_kwargs}
estimator = estimator_cls(**estimator_kwargs)
dataset = dataset_cls(**dataset_kwargs)
train_dataset, val_dataset, *aux_datasets = dataset.split(
fracs=[
dataset_split_fracs["train"],
dataset_split_fracs["val"],
*dataset_split_fracs["aux"]
],
**dataset_split_kwargs
)
train_loader = dataloader_cls(train_dataset, **dataloader_kwargs)
val_loader = dataloader_cls(val_dataset, **dataloader_kwargs)
aux_loaders = [
dataloader_cls(aux_dataset, **dataloader_kwargs)
for aux_dataset in aux_datasets
]
trainer = Trainer(
estimator,
**trainer_kwargs,
)
if load_only:
return trainer
return trainer.train(
train_loader=train_loader,
val_loader=val_loader,
aux_loaders=aux_loaders,
**train_kwargs,
)
def run_from_json(
    parameters_json: str | None = None,
    parameters_file: str | None = None,
) -> Trainer | Estimator:
    """Run training from JSON-encoded ``run()`` parameters.

    Exactly one source is needed: a raw JSON string, or a path to a JSON
    file. When both are given, the raw string takes precedence.

    Raises:
        ValueError: if neither source is given, or the JSON cannot be read
            or parsed.
    """
    if not (parameters_json or parameters_file):
        raise ValueError("parameter json or file required")
    parameters: dict[str, Any]
    if parameters_json:
        try:
            parameters = json.loads(parameters_json)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON format: {e}") from e
    else:
        try:
            with open(parameters_file, encoding="utf-8") as f:
                parameters = json.load(f)
        except FileNotFoundError as e:
            raise ValueError(f"JSON file not found: {parameters_file}") from e
        # narrowed from a bare Exception catch: only I/O and parse errors
        # are expected and wrapped here
        except (OSError, json.JSONDecodeError) as e:
            raise ValueError(f"Error loading JSON parameters: {e}") from e
    return run(**parameters)
def handle_train(args: argparse.Namespace) -> None:
    """argparse callback for the ``train`` subcommand; delegates to the JSON runner."""
    run_from_json(args.parameters_json, args.parameters_file)
def register_parser(subparsers: _SubParsersAction) -> None:
    """Attach the ``train`` subcommand to *subparsers*.

    The two parameter sources are mutually exclusive and one is required,
    so a bad invocation fails fast at argument-parsing time (with a usage
    message) instead of raising ValueError inside ``run_from_json``.
    """
    parser = subparsers.add_parser("train", help="run training loop")
    source = parser.add_mutually_exclusive_group(required=True)
    source.add_argument(
        "--parameters-json",
        type=str,
        help="Raw JSON string with train parameters",
    )
    source.add_argument(
        "--parameters-file",
        type=str,
        help="Path to JSON file with train parameters",
    )
    parser.set_defaults(func=handle_train)

102
trainlib/dataloader.py Normal file
View File

@@ -0,0 +1,102 @@
"""
This class took me a long time to really settle into. It's a connector, and it
feels redundant in many ways, so I've nearly deleted it several times while
talking through the design. But in total, I think it serves a clear purpose.
Reasons:
- Need a typed dataloader, even if I know the type of my attached transform
- Need a new scope that uses the same base dataset without interfering with the
transform attribute; a design that sets or relies on that is subject to
conflict
- Why not just use vanilla DataLoaders?
I'd like to, but the two reasons above make it clear why this is challenging:
I don't get static checks on the structures returned during iteration, and
while you can control ad hoc data transforms via dataset ``post_transforms``,
things can get messy if you need to do that for many transforms using the
same dataset (without copying). Simplest way around this is just a new scope
with the same underlying dataset instance and a transform wrapper around the
iterator; no interference with object attributes.
This is really just meant as the minimum viable logic needed to accomplish the
above - it's a very lightweight wrapper on the base ``DataLoader`` object.
There's an explicit type upper bound ``Kw: EstimatorKwargs``, but it is
otherwise a completely general transform over dataloader batches, highlighting
that it's *mostly* here to play nice with type checks.
"""
from typing import Unpack
from collections.abc import Iterator
from torch.utils.data import DataLoader
from trainlib.dataset import BatchedDataset
from trainlib.estimator import EstimatorKwargs
from trainlib.utils.type import LoaderKwargs
class EstimatorDataLoader[B, Kw: EstimatorKwargs]:
    """
    Data loaders for estimators.

    This class exists to connect batched data from datasets to the expected
    representation for estimator methods. Datasets may be developed
    independently from given model structures, and models should be trainable
    under any such data. We need a way to ensure the batched groups of items
    we get from dataloaders match on a type level, i.e., can be reshaped into
    the expected ``Kw`` signature.

    Note: batch structure ``B`` cannot be directly inferred from type
    variables exposed by ``BatchedDatasets`` (namely ``R`` and ``I``). What's
    returned by a data loader wrapping any such dataset can be arbitrary
    (depending on the ``collate_fn``), with default behavior being fairly
    consistent under nested collections but challenging to accurately type.

    .. todo::
        To log (have changed for Trainer):

        - New compact eval pipeline for train/val/auxiliary dataloaders.
          Somewhat annoying logic, but handled consistently
        - Convergence tracker will dynamically use training loss (early
          stopping) when a validation set isn't provided. Same mechanics for
          stagnant epochs (although early stopping is generally a little more
          nuanced, having a rate-based stopper, b/c train loss generally quite
          monotonic). So that's to be updated, plus room for possible model
          selection strategies later.
        - Logging happens at each batch, but we append to an epoch-indexed
          list and later average. There was a bug in the last round of
          testing that I didn't pick up where I was just overwriting
          summaries using the last seen batch.
        - Reworked general dataset/dataloader handling for main train loop,
          now accepting objects of this class to bridge estimator and dataset
          communication. This cleans up the batch mapping model.
        - TODO: implement a version of this that canonically works with the
          device passing plus EstimatorKwargs input; this is the last fuzzy
          bit I think.
    """

    def __init__(
        self,
        dataset: BatchedDataset,
        **dataloader_kwargs: Unpack[LoaderKwargs],
    ) -> None:
        # delegate all batching/collation behavior to a vanilla torch
        # DataLoader wrapped around the provided dataset
        self._dataloader = DataLoader(dataset, **dataloader_kwargs)

    def batch_to_est_kwargs(self, batch_data: B) -> Kw:
        """
        Map one raw batch (type ``B``) into estimator kwargs (type ``Kw``).

        Subclasses must override this; e.g. a minimal implementation could
        be ``return EstimatorKwargs({"inputs": ...})``.

        .. note::
            Even if we have a concrete shape for the output kwarg dict for
            base estimators (requiring a tensor "inputs" attribute), we
            don't presuppose how a given batch object will map into this
            dict structure.
        """
        raise NotImplementedError

    def __iter__(self) -> Iterator[Kw]:
        # lazily transform each underlying batch into estimator kwargs
        return map(self.batch_to_est_kwargs, self._dataloader)

View File

@@ -0,0 +1,12 @@
from trainlib.dataloader import EstimatorDataLoader
from trainlib.utils.text import camel_to_snake
from trainlib.dataloaders.memory import SupervisedDataLoader
# dataloader classes exposed through the string-keyed registry below
_dataloaders = [
    SupervisedDataLoader,
]

# registry mapping snake_case class names to dataloader classes
dataloader_map: dict[str, type[EstimatorDataLoader]] = {
    camel_to_snake(loader_cls.__name__): loader_cls
    for loader_cls in _dataloaders
}

View File

@@ -0,0 +1,17 @@
from torch import Tensor
from trainlib.estimator import SupervisedKwargs
from trainlib.dataloader import EstimatorDataLoader
class SupervisedDataLoader(
    EstimatorDataLoader[tuple[Tensor, Tensor], SupervisedKwargs]
):
    """Loader for supervised ``(inputs, labels)`` tuple batches."""

    def batch_to_est_kwargs(
        self,
        batch_data: tuple[Tensor, Tensor]
    ) -> SupervisedKwargs:
        # unpack the (inputs, labels) pair positionally
        inputs, labels = batch_data
        return SupervisedKwargs(inputs=inputs, labels=labels)

View File

@@ -0,0 +1,12 @@
from trainlib.dataset import BatchedDataset
from trainlib.utils.text import camel_to_snake
from trainlib.datasets.memory import RandomXYDataset
# dataset classes exposed through the string-keyed registry below
_datasets = [
    RandomXYDataset,
]

# registry mapping snake_case class names to dataset classes
dataset_map: dict[str, type[BatchedDataset]] = {
    camel_to_snake(dataset_cls.__name__): dataset_cls
    for dataset_cls in _datasets
}

View File

@@ -73,13 +73,33 @@ class RecordDataset[T: NamedTuple](HomogenousDataset[int, T, T]):
from typing import Unpack from typing import Unpack
import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from torch.utils.data import TensorDataset
from trainlib.domain import SequenceDomain from trainlib.domain import SequenceDomain
from trainlib.dataset import TupleDataset, DatasetKwargs from trainlib.dataset import TupleDataset, DatasetKwargs
class RandomXYDataset(TupleDataset[Tensor]):
    """Synthetic dataset of standard-normal (x, y) tensor pairs.

    Useful as a stand-in supervised dataset (e.g. for smoke tests); samples
    are drawn once at construction time.
    """

    def __init__(
        self,
        num_samples: int,
        input_dim: int,
        output_dim: int,
        **kwargs: Unpack[DatasetKwargs],
    ) -> None:
        # draw inputs first, then targets, to keep RNG consumption order stable
        inputs = torch.randn((num_samples, input_dim))
        targets = torch.randn((num_samples, output_dim))
        domain = SequenceDomain[tuple[Tensor, Tensor]](
            TensorDataset(inputs, targets),
        )
        super().__init__(domain, **kwargs)
class SlidingWindowDataset(TupleDataset[Tensor]): class SlidingWindowDataset(TupleDataset[Tensor]):
def __init__( def __init__(
self, self,

View File

View File

@@ -35,6 +35,10 @@ class EstimatorKwargs(TypedDict):
inputs: Tensor inputs: Tensor
class SupervisedKwargs(EstimatorKwargs):
    """Estimator kwargs for supervised settings: ``inputs`` plus ``labels``."""

    # target tensor paired with ``inputs``; assumed shape-compatible with
    # estimator outputs — TODO confirm against Estimator.loss implementations
    labels: Tensor
class Estimator[Kw: EstimatorKwargs](nn.Module): class Estimator[Kw: EstimatorKwargs](nn.Module):
""" """
Estimator base class. Estimator base class.

View File

@@ -0,0 +1,15 @@
from trainlib.estimator import Estimator
from trainlib.utils.text import camel_to_snake
from trainlib.estimators.mlp import MLP
from trainlib.estimators.rnn import LSTM, ConvGRU, MultiheadLSTM
# estimator classes exposed through the string-keyed registry below
_estimators: list[type[Estimator]] = [
    MLP,
    LSTM,
    MultiheadLSTM,
    ConvGRU,
]

# registry mapping snake_case class names to estimator classes
estimator_map: dict[str, type[Estimator]] = {
    camel_to_snake(estimator_cls.__name__): estimator_cls
    for estimator_cls in _estimators
}

View File

@@ -103,7 +103,7 @@ class MLP[Kw: MLPKwargs](Estimator[Kw]):
mae = F.l1_loss(predictions, labels).item() mae = F.l1_loss(predictions, labels).item()
return { return {
"mse": loss, # "mse": loss,
"mae": mae, "mae": mae,
"grad_norm": get_grad_norm(self) "grad_norm": get_grad_norm(self)
} }

View File

@@ -1,21 +1,21 @@
from functools import partial from typing import Any
from collections.abc import Callable from collections.abc import Callable
import numpy as np import numpy as np
import torch import torch
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from torch import Tensor from torch import Tensor
from torch.utils.data import DataLoader
from trainlib.trainer import Trainer from trainlib.trainer import Trainer
from trainlib.estimator import EstimatorKwargs from trainlib.estimator import EstimatorKwargs
from trainlib.dataloader import EstimatorDataLoader
from trainlib.utils.type import AxesArray, SubplotsKwargs from trainlib.utils.type import AxesArray, SubplotsKwargs
type SubplotFn = Callable[[plt.Axes, int, Tensor, Tensor], None] type SubplotFn = Callable[[plt.Axes, int, Tensor, Tensor], None]
type ContextFn = Callable[[plt.Axes, str], None] type ContextFn = Callable[[plt.Axes, str], None]
class Plotter[B, K: EstimatorKwargs]: class Plotter[Kw: EstimatorKwargs]:
""" """
TODOs: TODOs:
@@ -26,14 +26,15 @@ class Plotter[B, K: EstimatorKwargs]:
intervals broken over the training epochs at 0, 50, 100, 150, ... and intervals broken over the training epochs at 0, 50, 100, 150, ... and
highlight the best one, even if that's not actually the single best highlight the best one, even if that's not actually the single best
epoch) epoch)
- Implement data and dimension limits; in the instance dataloaders have
huge numbers of samples or labels are high-dimensional
""" """
def __init__( def __init__(
self, self,
trainer: Trainer[..., K], trainer: Trainer[Any, Kw],
dataloaders: list[DataLoader], dataloaders: list[EstimatorDataLoader[Any, Kw]],
batch_estimator_map: Callable[[B, Trainer], ...], kw_to_actual: Callable[[Kw], Tensor],
estimator_to_output_map: Callable[[K], ...],
dataloader_labels: list[str] | None = None, dataloader_labels: list[str] | None = None,
) -> None: ) -> None:
self.trainer = trainer self.trainer = trainer
@@ -41,47 +42,21 @@ class Plotter[B, K: EstimatorKwargs]:
self.dataloader_labels = ( self.dataloader_labels = (
dataloader_labels or list(map(str, range(1, len(dataloaders)+1))) dataloader_labels or list(map(str, range(1, len(dataloaders)+1)))
) )
self.batch_estimator_map = batch_estimator_map self.kw_to_actual = kw_to_actual
self.estimator_to_output_map = estimator_to_output_map
self._outputs: list[list[Tensor]] | None = None self._outputs: list[list[Tensor]] | None = None
self._metrics: list[list[dict[str, float]]] | None = None self._metrics: list[list[dict[str, float]]] | None = None
self._batch_outputs_fn = partial(
self.trainer.get_batch_outputs,
batch_estimator_map=batch_estimator_map
)
self._batch_metrics_fn = partial(
self.trainer.get_batch_metrics,
batch_estimator_map=batch_estimator_map
)
self._data_tuples = None self._data_tuples = None
@property
def outputs(self) -> list[list[Tensor]]:
if self._outputs is None:
self._outputs = [
list(map(self._batch_outputs_fn, loader))
for loader in self.dataloaders
]
return self._outputs
@property
def metrics(self) -> list[list[dict[str, float]]]:
if self._metrics is None:
self._metrics = [
list(map(self._batch_metrics_fn, loader))
for loader in self.dataloaders
]
return self._metrics
@property @property
def data_tuples(self) -> list[tuple[Tensor, Tensor, str]]: def data_tuples(self) -> list[tuple[Tensor, Tensor, str]]:
""" """
Produce data items; to be cached. Zip later with axes Produce data items; to be cached. Zip later with axes
""" """
self.trainer.estimator.eval()
if self._data_tuples is not None: if self._data_tuples is not None:
return self._data_tuples return self._data_tuples
@@ -89,10 +64,14 @@ class Plotter[B, K: EstimatorKwargs]:
for i, loader in enumerate(self.dataloaders): for i, loader in enumerate(self.dataloaders):
label = self.dataloader_labels[i] label = self.dataloader_labels[i]
batch = next(iter(loader)) actual = torch.cat([
est_kwargs = self.batch_estimator_map(batch, self.trainer) self.kw_to_actual(batch_kwargs).detach().cpu()
actual = self.estimator_to_output_map(est_kwargs) for batch_kwargs in loader
output = self._batch_outputs_fn(batch) ])
output = torch.cat([
self.trainer.estimator(**batch_kwargs)[0].detach().cpu()
for batch_kwargs in loader
])
data_tuples.append((actual, output, label)) data_tuples.append((actual, output, label))
@@ -219,6 +198,14 @@ class Plotter[B, K: EstimatorKwargs]:
Note: transform samples in dataloader definitions beforehand if you Note: transform samples in dataloader definitions beforehand if you
want to change data want to change data
.. todo::
Merge in logic from general diagnostics, allowing collapse from
either dim and transposing.
Later: multi-trial error bars, or at least the ability to pass that
downstream
Parameters: Parameters:
row_size: row_size:
col_size: col_size:
@@ -270,6 +257,12 @@ class Plotter[B, K: EstimatorKwargs]:
return fig, axes return fig, axes
# def plot_ordered(...): ...
# """
# Simple ordered view of output dimensions, with actual and output
# overlaid.
# """
def plot_actual_output( def plot_actual_output(
self, self,
row_size: int | float = 2, row_size: int | float = 2,
@@ -462,3 +455,74 @@ class Plotter[B, K: EstimatorKwargs]:
combine_dims=combine_dims, combine_dims=combine_dims,
figure_kwargs=figure_kwargs, figure_kwargs=figure_kwargs,
) )
def estimator_diagnostic(
    self,
    row_size: int | float = 2,
    col_size: int | float = 4,
    session_name: str | None = None,
    combine_groups: bool = False,
    combine_metrics: bool = False,
    transpose_layout: bool = False,
    figure_kwargs: SubplotsKwargs | None = None,
) -> tuple[plt.Figure, AxesArray]:
    """
    Plot per-epoch averages of every logged metric for one training session.

    Reads the trainer's event log (a private attribute; NOTE(review):
    consider a public accessor), nested as
    session -> group -> metric -> {epoch: [values]} — presumably; TODO
    confirm against ``Trainer._log_event``. Axes are laid out one row per
    group and one column per metric unless collapsed by the ``combine_*``
    flags.

    Parameters:
        row_size: height per subplot row, in figure-size units
        col_size: width per subplot column, in figure-size units
        session_name: session to plot; defaults to the first logged session
        combine_groups: draw all groups onto a single row of axes
        combine_metrics: draw all metrics onto a single column of axes
        transpose_layout: swap the row/column roles of groups and metrics
        figure_kwargs: extra kwargs forwarded to the subplot factory

    Returns:
        The created figure and its axes array.
    """
    session_map = self.trainer._event_log
    # default to the first recorded session
    session_name = session_name or next(iter(session_map))
    groups = session_map[session_name]
    # metric count is taken from the first group; assumes every group logs
    # the same set of metrics — TODO confirm
    num_metrics = len(groups[next(iter(groups))])
    # colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    rows = 1 if combine_groups else len(groups)
    cols = 1 if combine_metrics else num_metrics
    if transpose_layout:
        rows, cols = cols, rows
    fig, axes = self._create_subplots(
        rows=rows,
        cols=cols,
        row_size=row_size,
        col_size=col_size,
        figure_kwargs=figure_kwargs,
    )
    if transpose_layout:
        # transpose back so the indexing below stays (group, metric)
        axes = axes.T
    for i, group_name in enumerate(groups):
        axes_row = axes[0 if combine_groups else i]
        group_metrics = groups[group_name]
        for j, metric_name in enumerate(group_metrics):
            ax = axes_row[0 if combine_metrics else j]
            metric_dict = group_metrics[metric_name]
            # average the per-batch values recorded under each epoch key
            metric_data = np.array([
                (k, np.mean(v)) for k, v in metric_dict.items()
            ])
            # label/title choice depends on which dimensions share an axis
            if combine_groups and combine_metrics:
                label = f"{group_name}-{metric_name}"
                title_prefix = "all"
            elif combine_groups:
                label = group_name
                title_prefix = metric_name
            # elif combine_metrics:
            else:
                label = metric_name
                title_prefix = group_name
            # else:
            #     label = ""
            #     title_prefix = f"{group_name},{metric_name}"
            ax.plot(
                metric_data[:, 0],
                metric_data[:, 1],
                label=label,
                # color=colors[j],
            )
            ax.set_title(f"[{title_prefix}] Metrics over epochs")
            ax.set_xlabel("epoch")
            ax.set_ylabel("value")
            ax.legend()
    return fig, axes

View File

@@ -1,5 +1,15 @@
""" """
Core interface for training ``Estimators`` with ``Datasets`` Core interface for training ``Estimators`` with ``Datasets``
.. admonition:: Design of preview ``get_dataloaders()``
Note how much this method is doing, and the positivity in letting that be
more explicit elsewhere. The assignment of transforms to datasets before
wrapping as loaders is chief among these items, alongside the balancing and
splitting; I think those are hamfisted here to make it work with the old
setup, but I generally think it's not consistent with the name "get dataloaders"
(i.e., and also balance and split and set transforms)
""" """
import os import os
@@ -7,48 +17,42 @@ import time
import logging import logging
from io import BytesIO from io import BytesIO
from copy import deepcopy from copy import deepcopy
from typing import Any, Self from typing import Any
from pathlib import Path from pathlib import Path
from collections import defaultdict from collections import defaultdict
from collections.abc import Callable
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from torch import cuda, Tensor from torch import cuda, Tensor
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.nn.utils import clip_grad_norm_ from torch.nn.utils import clip_grad_norm_
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from trainlib.dataset import BatchedDataset
from trainlib.estimator import Estimator, EstimatorKwargs from trainlib.estimator import Estimator, EstimatorKwargs
from trainlib.transform import Transform from trainlib.utils.map import nested_defaultdict
from trainlib.utils.type import ( from trainlib.dataloader import EstimatorDataLoader
SplitKwargs,
LoaderKwargs,
BalanceKwargs,
)
from trainlib.utils.module import ModelWrapper from trainlib.utils.module import ModelWrapper
from trainlib.utils.session import ensure_same_device
logger: logging.Logger = logging.getLogger(__name__) logger: logging.Logger = logging.getLogger(__name__)
class Trainer[I, K: EstimatorKwargs]: class Trainer[I, Kw: EstimatorKwargs]:
""" """
Training interface for optimizing parameters of ``Estimators`` with Training interface for optimizing parameters of ``Estimators`` with
``Datasets``. ``Datasets``.
This class is generic to a dataset item type ``I`` and an estimator kwarg This class is generic to a dataset item type ``I`` and an estimator kwarg
type ``K``. These are the two primary components ``Trainer`` objects need type ``Kw``. These are the two primary components ``Trainer`` objects need
to coordinate: they ultimately rely on a provided map to ensure data items to coordinate: they ultimately rely on a provided map to ensure data items
(type ``I``) from a dataset are appropriately routed as inputs to key (type ``I``) from a dataset are appropriately routed as inputs to key
estimator methods (like ``forward()`` and ``loss()``), which accept inputs estimator methods (like ``forward()`` and ``loss()``), which accept inputs
of type ``K``. of type ``Kw``.
""" """
def __init__( def __init__(
self, self,
estimator: Estimator[K], estimator: Estimator[Kw],
device: str | None = None, device: str | None = None,
chkpt_dir: str = "chkpt/", chkpt_dir: str = "chkpt/",
tblog_dir: str = "tblog/", tblog_dir: str = "tblog/",
@@ -93,6 +97,7 @@ class Trainer[I, K: EstimatorKwargs]:
self.estimator = estimator self.estimator = estimator
self.estimator.to(self.device) self.estimator.to(self.device)
self._event_log = nested_defaultdict(4, list)
self.chkpt_dir = Path(chkpt_dir).resolve() self.chkpt_dir = Path(chkpt_dir).resolve()
self.tblog_dir = Path(tblog_dir).resolve() self.tblog_dir = Path(tblog_dir).resolve()
@@ -104,80 +109,61 @@ class Trainer[I, K: EstimatorKwargs]:
Set initial tracking parameters for the primary training loop. Set initial tracking parameters for the primary training loop.
""" """
self._epoch: int = 1 self._epoch: int = 0
self._summary = defaultdict(lambda: defaultdict(dict)) self._summary = defaultdict(lambda: defaultdict(list))
self._event_log = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
self._val_loss = float("inf") self._conv_loss = float("inf")
self._best_val_loss = float("inf") self._best_conv_loss = float("inf")
self._stagnant_epochs = 0 self._stagnant_epochs = 0
self._best_model_state_dict: dict[str, Any] = {} self._best_model_state_dict: dict[str, Any] = {}
def _train_epoch( def _train_epoch(
self, self,
train_loader: DataLoader, loader: EstimatorDataLoader[Any, Kw],
batch_estimator_map: Callable[[I, Self], K],
optimizers: tuple[Optimizer, ...], optimizers: tuple[Optimizer, ...],
max_grad_norm: float | None = None, max_grad_norm: float | None = None,
progress_bar: tqdm | None = None,
) -> list[float]: ) -> list[float]:
""" """
Train the estimator for a single epoch. Train the estimator for a single epoch.
.. admonition:: On summary writers
Estimators can have several optimizers, and therefore can emit
several losses. This is a fairly unique case, but it's needed when
we want to optimize particular parameters in a particular order
(as in multi-model architectures, e.g., GANs). Point being: we
always iterate over optimizers/losses, even in the common case
where there's just a single value, and we index collections across
batches accordingly.
A few of the trackers, with the same size as the number of
optimizers:
- ``train_loss_sums``: tracks loss sums across all batches for the
epoch, used to update the loop preview text after each batch with
the current average loss
- ``train_loss_items``: collects current batch losses, recorded by
the TB writer
If there are ``M`` optimizers/losses, we log ``M`` loss terms to
the TB writer after each *batch* (not epoch). We could aggregate at
an epoch level, but parameter updates take place after each batch,
so large model changes can occur over the course of an epoch
(whereas the model remains the same over the course batch evals).
""" """
loss_sums = [] loss_sums = []
self.estimator.train() self.estimator.train()
with tqdm(train_loader, unit="batch") as batches: for i, batch_kwargs in enumerate(loader):
for i, batch_data in enumerate(batches): batch_kwargs = ensure_same_device(batch_kwargs, self.device)
est_kwargs = batch_estimator_map(batch_data, self) losses = self.estimator.loss(**batch_kwargs)
losses = self.estimator.loss(**est_kwargs)
for o_idx, (loss, optimizer) in enumerate( for o_idx, (loss, optimizer) in enumerate(
zip(losses, optimizers, strict=True) zip(losses, optimizers, strict=True)
): ):
if len(loss_sums) <= o_idx: if len(loss_sums) <= o_idx:
loss_sums.append(0.0) loss_sums.append(0.0)
loss_sums[o_idx] += loss.item() loss_sums[o_idx] += loss.item()
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
# clip gradients for optimizer's parameters # clip gradients for optimizer's parameters
if max_grad_norm is not None: if max_grad_norm is not None:
clip_grad_norm_( clip_grad_norm_(
self._get_optimizer_parameters(optimizer), self._get_optimizer_parameters(optimizer),
max_norm=max_grad_norm max_norm=max_grad_norm
) )
optimizer.step() optimizer.step()
# set loop loss to running average (reducing if multi-loss) # set loop loss to running average (reducing if multi-loss)
loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1)) loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
batches.set_postfix(loss=f"{loss_avg:8.2f}")
if progress_bar:
progress_bar.update(1)
progress_bar.set_postfix(
epoch=self._epoch,
mode="opt",
data="train",
loss=f"{loss_avg:8.2f}",
)
# step estimator hyperparam schedules # step estimator hyperparam schedules
self.estimator.epoch_step() self.estimator.epoch_step()
@@ -186,9 +172,9 @@ class Trainer[I, K: EstimatorKwargs]:
def _eval_epoch( def _eval_epoch(
self, self,
loader: DataLoader, loader: EstimatorDataLoader[Any, Kw],
batch_estimator_map: Callable[[I, Self], K], label: str,
loader_label: str, progress_bar: tqdm | None = None,
) -> list[float]: ) -> list[float]:
""" """
Perform and record validation scores for a single epoch. Perform and record validation scores for a single epoch.
@@ -216,55 +202,65 @@ class Trainer[I, K: EstimatorKwargs]:
loss_sums = [] loss_sums = []
self.estimator.eval() self.estimator.eval()
with tqdm(loader, unit="batch") as batches: for i, batch_kwargs in enumerate(loader):
for i, batch_data in enumerate(batches): batch_kwargs = ensure_same_device(batch_kwargs, self.device)
est_kwargs = batch_estimator_map(batch_data, self) losses = self.estimator.loss(**batch_kwargs)
losses = self.estimator.loss(**est_kwargs)
# one-time logging
if self._epoch == 0:
self._writer.add_graph(
ModelWrapper(self.estimator), est_kwargs
)
# once-per-epoch logging
if i == 0:
self.estimator.epoch_write(
self._writer,
step=self._epoch,
group=loader_label,
**est_kwargs
)
loss_items = [] # once-per-session logging
for o_idx, loss in enumerate(losses): if self._epoch == 0 and i == 0:
if len(loss_sums) <= o_idx: self._writer.add_graph(
loss_sums.append(0.0) ModelWrapper(self.estimator), batch_kwargs
)
# once-per-epoch logging
if i == 0:
self.estimator.epoch_write(
self._writer,
step=self._epoch,
group=label,
**batch_kwargs
)
loss_item = loss.item() loss_items = []
loss_sums[o_idx] += loss_item for o_idx, loss in enumerate(losses):
loss_items.append(loss_item) if len(loss_sums) <= o_idx:
loss_sums.append(0.0)
# set loop loss to running average (reducing if multi-loss) loss_item = loss.item()
loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1)) loss_sums[o_idx] += loss_item
batches.set_postfix(loss=f"{loss_avg:8.2f}") loss_items.append(loss_item)
# log individual loss terms after each batch # set loop loss to running average (reducing if multi-loss)
for o_idx, loss_item in enumerate(loss_items): loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
self._log_event(loader_label, f"loss_{o_idx}", loss_item)
# log metrics for batch if progress_bar:
estimator_metrics = self.estimator.metrics(**est_kwargs) progress_bar.update(1)
for metric_name, metric_value in estimator_metrics.items(): progress_bar.set_postfix(
self._log_event(loader_label, metric_name, metric_value) epoch=self._epoch,
mode="eval",
data=label,
loss=f"{loss_avg:8.2f}",
)
return loss_sums # log individual loss terms after each batch
for o_idx, loss_item in enumerate(loss_items):
self._log_event(label, f"loss_{o_idx}", loss_item)
# log metrics for batch
estimator_metrics = self.estimator.metrics(**batch_kwargs)
for metric_name, metric_value in estimator_metrics.items():
self._log_event(label, metric_name, metric_value)
avg_losses = [loss_sum / (i+1) for loss_sum in loss_sums]
return avg_losses
def _eval_loaders( def _eval_loaders(
self, self,
loaders: list[DataLoader], train_loader: EstimatorDataLoader[Any, Kw],
batch_estimator_map: Callable[[I, Self], K], val_loader: EstimatorDataLoader[Any, Kw] | None = None,
loader_labels: list[str], aux_loaders: list[EstimatorDataLoader[Any, Kw]] | None = None,
) -> dict[str, list[float]]: progress_bar: tqdm | None = None,
) -> tuple[list[float], list[float] | None, *list[float]]:
""" """
Evaluate estimator over each provided dataloader. Evaluate estimator over each provided dataloader.
@@ -278,42 +274,73 @@ class Trainer[I, K: EstimatorKwargs]:
``get_batch_outputs()`` or ``get_batch_metrics()`` while iterating over ``get_batch_outputs()`` or ``get_batch_metrics()`` while iterating over
batches. This will have no internal side effects and provides much more batches. This will have no internal side effects and provides much more
information (just aggregated losses are provided here). information (just aggregated losses are provided here).
.. admonition:: On epoch counting
Epoch counts start at 0 to allow for a sensible place to benchmark
the initial (potentially untrained/pre-trained) model before any
training data is seen. In the train loop, we increment the epoch
immediately, and all logging happens under the epoch value that's
set at the start of the iteration (rather than incrementing at the
end). Before beginning an additional training iteration, the
convergence condition in the ``while`` is effectively checking what
happened during the last epoch (the counter has not yet been
incremented); if no convergence, we begin again. (This is only
being noted because the epoch counting was previously quite
different: indexing started at ``1``, we incremented at the end of
the loop, and we didn't evaluate the model before the loop began.
This affects how we interpret plots and TensorBoard records, for
instance, so it's useful to spell out the approach clearly
somewhere given the many possible design choices here.)
""" """
return { train_loss = self._eval_epoch(train_loader, "train", progress_bar)
label: self._eval_epoch(loader, batch_estimator_map, label)
for loader, label in zip(loaders, loader_labels, strict=True)
}
def train[B]( val_loss = None
if val_loader is not None:
val_loss = self._eval_epoch(val_loader, "val", progress_bar)
aux_loaders = aux_loaders or []
aux_losses = [
self._eval_epoch(aux_loader, f"aux{i}", progress_bar)
for i, aux_loader in enumerate(aux_loaders)
]
return train_loss, val_loss, *aux_losses
def train(
self, self,
dataset: BatchedDataset[..., ..., I], train_loader: EstimatorDataLoader[Any, Kw],
batch_estimator_map: Callable[[B, Self], K], val_loader: EstimatorDataLoader[Any, Kw] | None = None,
aux_loaders: list[EstimatorDataLoader[Any, Kw]] | None = None,
*,
lr: float = 1e-3, lr: float = 1e-3,
eps: float = 1e-8, eps: float = 1e-8,
max_grad_norm: float | None = None, max_grad_norm: float | None = None,
max_epochs: int = 10, max_epochs: int = 10,
stop_after_epochs: int = 5, stop_after_epochs: int = 5,
batch_size: int = 256,
val_frac: float = 0.1,
train_transform: Transform | None = None,
val_transform: Transform | None = None,
dataset_split_kwargs: SplitKwargs | None = None,
dataset_balance_kwargs: BalanceKwargs | None = None,
dataloader_kwargs: LoaderKwargs | None = None,
summarize_every: int = 1, summarize_every: int = 1,
chkpt_every: int = 1, chkpt_every: int = 1,
resume_latest: bool = False,
session_name: str | None = None, session_name: str | None = None,
summary_writer: SummaryWriter | None = None, summary_writer: SummaryWriter | None = None,
aux_loaders: list[DataLoader] | None = None,
aux_loader_labels: list[str] | None = None,
) -> Estimator: ) -> Estimator:
""" """
TODO: consider making the dataloader ``collate_fn`` an explicit .. todo::
parameter with a type signature that reflects ``B``, connecting the
``batch_estimator_map`` somewhere. Might also re-type a ``DataLoader`` - consider making the dataloader ``collate_fn`` an explicit
in-house to allow a generic around ``B`` parameter with a type signature that reflects ``B``, connecting
the ``batch_estimator_map`` somewhere. Might also re-type a
``DataLoader`` in-house to allow a generic around ``B``
- Rework the validation specification. Accept something like a
"validate_with" parameter, or perhaps just move entirely to
accepting a dataloader list, label list. You might then also need
a "train_with," and you could set up sensible defaults so you
basically have the same interaction as now. The "problem" is you
always need a train set, and there's some clearly dependent logic
on a val set, but you don't *need* val, so this should be
slightly reworked (and the more general, *probably* the better in
this case, given I want to plug into the Plotter with possibly
several purely eval sets over the model training lifetime).
Note: this method attempts to implement a general scheme for passing Note: this method attempts to implement a general scheme for passing
needed items to the estimator's loss function from the dataloader. The needed items to the estimator's loss function from the dataloader. The
@@ -355,7 +382,7 @@ class Trainer[I, K: EstimatorKwargs]:
This function should map from batches - which *may* be item This function should map from batches - which *may* be item
shaped, i.e., have an ``I`` skeleton, even if stacked items may be shaped, i.e., have an ``I`` skeleton, even if stacked items may be
different on the inside - into estimator keyword arguments (type different on the inside - into estimator keyword arguments (type
``K``). Collation behavior from a DataLoader (which can be ``Kw``). Collation behavior from a DataLoader (which can be
customized) doesn't consistently yield a known type shape, however, customized) doesn't consistently yield a known type shape, however,
so it's not appropriate to use ``I`` as the callable param type. so it's not appropriate to use ``I`` as the callable param type.
@@ -416,84 +443,63 @@ class Trainer[I, K: EstimatorKwargs]:
dataset dataset
val_split_frac: fraction of dataset to use for validation val_split_frac: fraction of dataset to use for validation
chkpt_every: how often model checkpoints should be saved chkpt_every: how often model checkpoints should be saved
resume_latest: resume training from the latest available checkpoint
in the `chkpt_dir`
""" """
logger.info("> Begin train loop:") logger.info("> Begin train loop:")
logger.info(f"| > {lr=}") logger.info(f"| > {lr=}")
logger.info(f"| > {eps=}") logger.info(f"| > {eps=}")
logger.info(f"| > {max_epochs=}") logger.info(f"| > {max_epochs=}")
logger.info(f"| > {batch_size=}")
logger.info(f"| > {val_frac=}")
logger.info(f"| > {chkpt_every=}") logger.info(f"| > {chkpt_every=}")
logger.info(f"| > {resume_latest=}")
logger.info(f"| > with device: {self.device}") logger.info(f"| > with device: {self.device}")
logger.info(f"| > core count: {os.cpu_count()}") logger.info(f"| > core count: {os.cpu_count()}")
self._session_name = session_name or str(int(time.time())) self._session_name = session_name or str(int(time.time()))
tblog_path = Path(self.tblog_dir, self._session_name) tblog_path = Path(self.tblog_dir, self._session_name)
self._writer = summary_writer or SummaryWriter(f"{tblog_path}") self._writer = summary_writer or SummaryWriter(f"{tblog_path}")
progress_bar = tqdm(train_loader, unit="batch")
aux_loaders = aux_loaders or []
aux_loader_labels = aux_loader_labels or []
optimizers = self.estimator.optimizers(lr=lr, eps=eps)
train_loader, val_loader = self.get_dataloaders(
dataset,
batch_size,
val_frac=val_frac,
train_transform=train_transform,
val_transform=val_transform,
dataset_split_kwargs=dataset_split_kwargs,
dataset_balance_kwargs=dataset_balance_kwargs,
dataloader_kwargs=dataloader_kwargs,
)
loaders = [train_loader, val_loader, *aux_loaders]
loader_labels = ["train", "val", *aux_loader_labels]
# evaluate model on dataloaders once before training starts # evaluate model on dataloaders once before training starts
self._eval_loaders(loaders, batch_estimator_map, loader_labels) self._eval_loaders(train_loader, val_loader, aux_loaders, progress_bar)
optimizers = self.estimator.optimizers(lr=lr, eps=eps)
while self._epoch <= max_epochs and not self._converged( while self._epoch < max_epochs and not self._converged(
self._epoch, stop_after_epochs self._epoch, stop_after_epochs
): ):
self._epoch += 1
train_frac = f"{self._epoch}/{max_epochs}" train_frac = f"{self._epoch}/{max_epochs}"
stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}" stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}"
print(f"Training epoch {train_frac}...") #print(f"Training epoch {train_frac}...")
print(f"Stagnant epochs {stag_frac}...") #print(f"Stagnant epochs {stag_frac}...")
epoch_start_time = time.time() epoch_start_time = time.time()
self._train_epoch( self._train_epoch(
train_loader, train_loader,
batch_estimator_map,
optimizers, optimizers,
max_grad_norm, max_grad_norm,
progress_bar=progress_bar
) )
epoch_end_time = time.time() - epoch_start_time epoch_end_time = time.time() - epoch_start_time
self._log_event("train", "epoch_duration", epoch_end_time) self._log_event("train", "epoch_duration", epoch_end_time)
loss_sum_map = self._eval_loaders( train_loss, val_loss, *_ = self._eval_loaders(
loaders, train_loader, val_loader, aux_loaders, progress_bar
batch_estimator_map,
loader_labels,
) )
val_loss_sums = loss_sum_map["val"] # determine loss to use for measuring convergence
self._val_loss = sum(val_loss_sums) / len(val_loader) conv_loss = val_loss if val_loss else train_loss
self._conv_loss = sum(conv_loss) / len(conv_loss)
if self._epoch % summarize_every == 0: if self._epoch % summarize_every == 0:
self._summarize() self._summarize()
if self._epoch % chkpt_every == 0: if self._epoch % chkpt_every == 0:
self.save_model() self.save_model()
self._epoch += 1
return self.estimator return self.estimator
def _converged(self, epoch: int, stop_after_epochs: int) -> bool: def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
converged = False converged = False
if epoch == 1 or self._val_loss < self._best_val_loss: if epoch == 0 or self._conv_loss < self._best_val_loss:
self._best_val_loss = self._val_loss self._best_val_loss = self._conv_loss
self._stagnant_epochs = 0 self._stagnant_epochs = 0
self._best_model_state_dict = deepcopy(self.estimator.state_dict()) self._best_model_state_dict = deepcopy(self.estimator.state_dict())
else: else:
@@ -505,110 +511,25 @@ class Trainer[I, K: EstimatorKwargs]:
return converged return converged
@staticmethod
def get_dataloaders(
    dataset: BatchedDataset,
    batch_size: int,
    val_frac: float = 0.1,
    train_transform: Transform | None = None,
    val_transform: Transform | None = None,
    dataset_split_kwargs: SplitKwargs | None = None,
    dataset_balance_kwargs: BalanceKwargs | None = None,
    dataloader_kwargs: LoaderKwargs | None = None,
) -> tuple[DataLoader, DataLoader]:
    """
    Create training and validation dataloaders for the provided dataset.

    Optionally balances the dataset in place, splits it into train/val
    fractions, attaches per-split post-transforms, and wraps each split in
    a ``DataLoader``. Default loader settings (batch size capped at the
    split length, ``num_workers=0``, ``shuffle=True``) can be overridden
    via ``dataloader_kwargs``, which is merged over both loaders' defaults.

    .. todo::
        Decide on policy for empty val dataloaders
    """
    if dataset_split_kwargs is None:
        dataset_split_kwargs = {}
    # balance() mutates the dataset before any splitting takes place
    if dataset_balance_kwargs is not None:
        dataset.balance(**dataset_balance_kwargs)
    if val_frac <= 0:
        # no validation requested: train on the full dataset
        dataset.post_transform = train_transform
        train_loader_kwargs: LoaderKwargs = {
            "batch_size": min(batch_size, len(dataset)),
            "num_workers": 0,
            "shuffle": True,
        }
        if dataloader_kwargs is not None:
            # caller-provided kwargs override the defaults above
            train_loader_kwargs: LoaderKwargs = {
                **train_loader_kwargs,
                **dataloader_kwargs
            }
        # NOTE(review): the val slot is a loader over an empty Dataset();
        # presumably callers must tolerate a zero-length val loader -- see
        # the docstring todo and confirm before relying on it
        return (
            DataLoader(dataset, **train_loader_kwargs),
            DataLoader(Dataset())
        )
    train_dataset, val_dataset = dataset.split(
        [1 - val_frac, val_frac],
        **dataset_split_kwargs,
    )
    # Dataset.split() returns light Subset objects of shallow copies of the
    # underlying dataset; can change the transform attribute of both splits
    # w/o overwriting
    train_dataset.post_transform = train_transform
    val_dataset.post_transform = val_transform
    train_loader_kwargs: LoaderKwargs = {
        "batch_size": min(batch_size, len(train_dataset)),
        "num_workers": 0,
        "shuffle": True,
    }
    val_loader_kwargs: LoaderKwargs = {
        "batch_size": min(batch_size, len(val_dataset)),
        "num_workers": 0,
        "shuffle": True,  # shuffle to prevent homogeneous val batches
    }
    if dataloader_kwargs is not None:
        # the same overrides are applied to both loaders
        train_loader_kwargs = {**train_loader_kwargs, **dataloader_kwargs}
        val_loader_kwargs = {**val_loader_kwargs, **dataloader_kwargs}
    train_loader = DataLoader(train_dataset, **train_loader_kwargs)
    val_loader = DataLoader(val_dataset, **val_loader_kwargs)
    return train_loader, val_loader
def _summarize(self) -> None: def _summarize(self) -> None:
""" """
Flush the training summary to the TensorBoard summary writer. Flush the training summary to the TensorBoard summary writer and print
metrics for the current epoch.
.. note:: Possibly undesirable behavior
Currently, this method aggregates metrics for the epoch summary
across all logged items *in between summarize calls*. For instance,
if I'm logging every 10 epochs, the stats at epoch=10 are actually
averages from epochs 1-10.
""" """
epoch_values = defaultdict(lambda: defaultdict(list))
for group, records in self._summary.items():
for name, steps in records.items():
for step, value in steps.items():
self._writer.add_scalar(f"{group}-{name}", value, step)
if step == self._epoch:
epoch_values[group][name].append(value)
print(f"==== Epoch [{self._epoch}] summary ====") print(f"==== Epoch [{self._epoch}] summary ====")
for group, records in epoch_values.items(): for (group, name), epoch_map in self._summary.items():
for name, values in records.items(): for epoch, values in epoch_map.items():
mean_value = torch.tensor(values).mean().item() # compute average over batch items recorded for the epoch
print( mean = torch.tensor(values).mean().item()
f"> ({len(values)}) [{group}] {name} :: {mean_value:.2f}" self._writer.add_scalar(f"{group}-{name}", mean, epoch)
) if epoch == self._epoch:
print(
f"> ({len(values)}) [{group}] {name} :: {mean:.2f}"
)
self._writer.flush() self._writer.flush()
self._summary = defaultdict(lambda: defaultdict(dict)) self._summary = defaultdict(lambda: defaultdict(list))
def _get_optimizer_parameters( def _get_optimizer_parameters(
self, self,
@@ -622,33 +543,10 @@ class Trainer[I, K: EstimatorKwargs]:
] ]
def _log_event(self, group: str, name: str, value: float) -> None: def _log_event(self, group: str, name: str, value: float) -> None:
self._summary[group][name][self._epoch] = value session, epoch = self._session_name, self._epoch
self._event_log[self._session_name][group][name][self._epoch] = value
def get_batch_outputs[B]( self._summary[group, name][epoch].append(value)
self, self._event_log[session][group][name][epoch].append(value)
batch: B,
batch_estimator_map: Callable[[B, Self], K],
) -> Tensor:
self.estimator.eval()
est_kwargs = batch_estimator_map(batch, self)
output = self.estimator(**est_kwargs)[0]
output = output.detach().cpu()
return output
def get_batch_metrics[B](
self,
batch: B,
batch_estimator_map: Callable[[B, Self], K],
) -> dict[str, float]:
self.estimator.eval()
est_kwargs = batch_estimator_map(batch, self)
metrics = self.estimator.metrics(**est_kwargs)
return metrics
def save_model(self) -> None: def save_model(self) -> None:
""" """

10
trainlib/utils/map.py Normal file
View File

@@ -0,0 +1,10 @@
from collections import defaultdict
def nested_defaultdict(
    depth: int,
    final: type = dict,
) -> defaultdict:
    """
    Build a ``defaultdict`` nested ``depth`` levels deep.

    Missing keys at the first ``depth - 1`` levels auto-create another
    nested ``defaultdict``; missing keys at the innermost level create
    ``final()``.

    Args:
        depth: number of nesting levels; must be >= 1
        final: factory for innermost default values (e.g. ``dict``,
            ``list``, ``float``)

    Returns:
        the lazily nested ``defaultdict``

    Raises:
        ValueError: if ``depth`` is less than 1. (Previously a
            non-positive depth was accepted silently and produced
            factories that decrement ``depth`` forever, never reaching
            the ``depth == 1`` base case on key access.)
    """
    if depth < 1:
        raise ValueError(f"depth must be >= 1, got {depth}")
    if depth == 1:
        return defaultdict(final)
    # each missing key lazily creates the next, shallower level
    return defaultdict(lambda: nested_defaultdict(depth - 1, final))

View File

@@ -3,6 +3,7 @@ import random
import numpy as np import numpy as np
import torch import torch
from torch import Tensor from torch import Tensor
from torch.utils import _pytree as pytree
def seed_all_backends(seed: int | Tensor | None = None) -> None: def seed_all_backends(seed: int | Tensor | None = None) -> None:
@@ -19,3 +20,9 @@ def seed_all_backends(seed: int | Tensor | None = None) -> None:
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
def ensure_same_device[T](tree: T, device: str) -> T:
return pytree.tree_map(
lambda x: x.to(device) if isinstance(x, torch.Tensor) else x,
tree,
)

View File

@@ -1,8 +1,14 @@
import re
from typing import Any from typing import Any
from colorama import Style from colorama import Style
camel2snake_regex: re.Pattern[str] = re.compile(
    r"(?<!^)(?=[A-Z][a-z])|(?<=[a-z])(?=[A-Z])"
)


def camel_to_snake(text: str) -> str:
    """
    Rewrite a CamelCase/mixedCase identifier as snake_case.

    Word boundaries are the zero-width positions matched by
    ``camel2snake_regex`` (before an upper-lower pair, or between a
    lowercase and an uppercase letter); acronym runs like ``HTTPServer``
    split as ``http_server``.
    """
    words = camel2snake_regex.split(text)
    return "_".join(words).lower()
def color_text(text: str, *colorama_args: Any) -> str: def color_text(text: str, *colorama_args: Any) -> str:
return f"{''.join(colorama_args)}{text}{Style.RESET_ALL}" return f"{''.join(colorama_args)}{text}{Style.RESET_ALL}"

View File

@@ -11,7 +11,7 @@ from trainlib.dataset import BatchedDataset
type AxesArray = np.ndarray[tuple[int, int], np.dtype[np.object_]] type AxesArray = np.ndarray[tuple[int, int], np.dtype[np.object_]]
class LoaderKwargs(TypedDict, total=False): class LoaderKwargs(TypedDict, total=False):
batch_size: int batch_size: int | None
shuffle: bool shuffle: bool
sampler: Sampler | Iterable | None sampler: Sampler | Iterable | None
batch_sampler: Sampler[list] | Iterable[list] | None batch_sampler: Sampler[list] | Iterable[list] | None
@@ -61,8 +61,8 @@ class SubplotsKwargs(TypedDict, total=False):
squeeze: bool squeeze: bool
width_ratios: Sequence[float] width_ratios: Sequence[float]
height_ratios: Sequence[float] height_ratios: Sequence[float]
subplot_kw: dict[str, ...] subplot_kw: dict[str, Any]
gridspec_kw: dict[str, ...] gridspec_kw: dict[str, Any]
figsize: tuple[float, float] figsize: tuple[float, float]
dpi: float dpi: float
layout: str layout: str

2
uv.lock generated
View File

@@ -1659,7 +1659,7 @@ wheels = [
[[package]] [[package]]
name = "trainlib" name = "trainlib"
version = "0.1.2" version = "0.2.1"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "colorama" }, { name = "colorama" },