initial commit

2026-03-03 18:11:37 -08:00
commit 337175d428
24 changed files with 4940 additions and 0 deletions

.gitignore vendored Normal file

@@ -0,0 +1,20 @@
# generic
__pycache__/
*.egg-info/
.python-version
# package-specific
.ipynb_checkpoints/
.pytest_cache/
# vendor/build files
dist/
build/
doc/_autoref/
doc/_autosummary/
doc/_build/
# misc local
/Makefile
notebooks/

README.md Normal file

@@ -0,0 +1,105 @@
# Overview
Package summary goes here, ideally with a diagram
# Install
Installation instructions
```sh
pip install <package>
```
or as a CLI tool
```sh
uv tool install <package>
```
# Development
- Initialize/synchronize the project with `uv sync`, creating a virtual
environment with base package dependencies.
- Depending on needs, install the development dependencies with `uv sync
--extra dev`.
# Testing
- To run the unit tests, make sure to first have the test dependencies
installed with `uv sync --extra test`, then run `make test`.
- For notebook testing, run `make install-kernel` to make the environment
available as a Jupyter kernel (to be selected when running notebooks).
# Documentation
- Install the documentation dependencies with `uv sync --extra doc`.
- Run `make docs-build` (optionally preceded by `make docs-clean`), and serve
locally with `make docs-serve`.
# Development remarks
- Across `Trainer` / `Estimator` / `Dataset`, I've considered a
`ParamSpec`-based typing scheme to better orchestrate alignment in the
`Trainer.train()` loop, e.g., so we can statically check whether a dataset
appears to be fulfilling the argument requirements for the estimator's
`loss()` / `metrics()` methods. Something like
```py
class Estimator[**P](nn.Module):
def loss(
self,
input: Tensor,
*args: P.args,
**kwargs: P.kwargs,
) -> Generator:
...
class Trainer[**P]:
def __init__(
self,
estimator: Estimator[P],
...
): ...
```
might be how we begin threading signatures. But ensuring dataset items can
match `P` is challenging. One option is a "packed" object that encapsulates
the data behind a `P`-shaped signature:
```py
class PackedItem[**P]:
def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
self._args = args
self._kwargs = kwargs
def apply[R](self, func: Callable[P, R]) -> R:
return func(*self._args, **self._kwargs)
class BatchedDataset[U, R, I, **P](Dataset):
@abstractmethod
def _process_item_data(
self,
item_data: I,
item_index: int,
) -> PackedItem[P]:
...
def __iter__(self) -> Iterator[PackedItem[P]]:
...
```
Meaningfully shaping those signatures is what remains, but current type
expressions aren't flexible enough to do it. For instance, when trying to
appropriately type my base `TupleDataset`:
```py
class SequenceDataset[I, **P](HomogenousDataset[int, I, I, P]):
...
class TupleDataset[I](SequenceDataset[tuple[I, ...], ??]):
...
```
Here there's no way for me to shape a `ParamSpec` to indicate arbitrarily
many arguments of a fixed type (`I` in this case) to allow me to unpack my
item tuples into an appropriate `PackedItem`.
Until this (among other issues) becomes clearer, I'm setting up around a
simpler `TypedDict` type variable. We won't have particularly strong static
checks for item alignment inside `Trainer`, but this seems about as good as I
can get around the current infrastructure.

pyproject.toml Normal file

@@ -0,0 +1,84 @@
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "trainlib"
version = "0.1.0"
description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training."
requires-python = ">=3.13"
authors = [
{ name="Sam Griesemer", email="git@olog.io" },
]
readme = "README.md"
license = "MIT"
keywords = [
"machine-learning",
]
classifiers = [
"Programming Language :: Python",
"Operating System :: OS Independent",
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Intended Audience :: End Users/Desktop",
]
dependencies = [
"colorama>=0.4.6",
"matplotlib>=3.10.8",
"numpy>=2.4.1",
"tensorboard>=2.20.0",
"torch>=2.5.1",
"tqdm>=4.67.1",
]
[project.scripts]
trainlib = "trainlib.__main__:main"
[project.optional-dependencies]
dev = [
"ipykernel",
]
doc = [
"furo",
"myst-parser",
"sphinx",
"sphinx-togglebutton",
"sphinx-autodoc-typehints",
]
test = [
"pytest",
]
[project.urls]
Homepage = "https://doc.olog.io/trainlib"
Documentation = "https://doc.olog.io/trainlib"
Repository = "https://git.olog.io/olog/trainlib"
Issues = "https://git.olog.io/olog/trainlib/issues"
[tool.setuptools.packages.find]
include = ["trainlib*"]
# for static data files under package root
# [tool.setuptools.package-data]
# "<package>" = ["data/*.toml"]
[tool.ruff]
line-length = 79
[tool.ruff.lint]
select = ["ANN", "E", "F", "UP", "B", "SIM", "I", "C4", "PERF"]
[tool.ruff.lint.isort]
length-sort = true
order-by-type = false
force-sort-within-sections = false
[tool.ruff.lint.per-file-ignores]
"tests/**" = ["S101"]
"**/__init__.py" = ["F401"]
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
docstring-code-format = true

trainlib/__init__.py Normal file

trainlib/dataset.py Normal file

@@ -0,0 +1,964 @@
"""
Marginalizing out the modality layer:
With ``domain`` being an instance variable, one possible interpretation of
the object structures here is that one could completely abstract away
the domain model, defining only item structures and processing data. You
could have a single dataset definition for a particular concrete dataset,
and so long as we're talking about the same items, it can be instantiated
using *any domain*. You wouldn't need specific subclasses for disk or
network or in-memory; you can tell it directly at runtime.
That's an eventual possibility, anyway. As it stands, however, this is
effectively impossible:
You can't easily abstract the batch -> item splitting process, i.e.,
``_process_batch_data()``. A list-based version of the dataset you're
trying to define might have an individual item tuple at every index,
whereas a disk-based version might have tuples batched across a few files.
This can't reliably be inferred, nor can it be pushed to the
``Domain``-level without needing equal levels of specialization (you'd just
end up needing the exact same structural distinctions in the ``Domain``
hierarchy). So *somewhere* you need a batch splitting implementation that
is both item structure-dependent *and* domain-dependent...the question is
how dynamic you're willing to be about where it comes from. Right now, we
require this actually be defined in the ``_process_batch_data()`` method,
meaning you'll need a specific ``Dataset`` class for each domain you want
to support (e.g., ``MNISTDisk``, ``MNISTList``, ``MNISTNetwork``, etc), or
at least for each domain where "interpreting" a batch could possibly
differ. This is a case where the interface is all that enforces a
distinction: if you've got two domains that can be counted on to yield
batches in the exact same way and can use the same processing, then you
could feasibly provide ``Domain`` objects from either at runtime and have
no issues. We're "structurally blind" to any differentiation beyond the URI
and resource types by design, so two different domain implementations with
the same type signature ``Domain[U, R]`` should be expected to work fine at
runtime (again, so long as they don't also need different batch
processing), but that's not affording us much flexibility, i.e., most of
the time we'll still be defining new dataset classes for each domain.
I initially flagged this as feasible, however, because one could imagine
accepting a batch processing method upon instantiation rather than
structurally bolting it into the ``Dataset`` definition. This would require
knowledge of the item structure ``I`` as well as the ``Domain[U, R]``, so
such a function will always have to be (I, U, R)-dependent. It nevertheless
would take out some of the pain of having to define new dataset classes;
instead, you'd just need to define the batch processing method. I see this
as a worse alternative to just defining *inside* a safe context like a new
dataset class: you know the types you have to respect, and you stick that
method exactly in a context where it's understood. Freeing this up doesn't
lighten the burden of processing logic; it just changes *when* it has to be
provided, and that's not worth much (to me) in this case given the bump in
complexity. (Taking this to the extreme: you could supply *all* of an
object's methods "dynamically" and glue them together at runtime so long as
they all played nice. But wherever you were "laying them out" beforehand is
exactly the job of a class to begin with, so you don't end up with anything
more dynamic. All we're really discussing here is pushing around
unavoidable complexity inside and outside of the "class walls," and in the
particular case of ``_process_batch_data()``, it feels much better when
it's on the inside.)
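To make the rejected alternative concrete, it would look roughly like this
(a sketch only; ``process_batch`` is a hypothetical constructor parameter,
not part of the actual API):
```python
from collections.abc import Callable

# hypothetical: the batch processor is supplied at instantiation rather
# than defined on a Dataset subclass
class DynamicDataset:
    def __init__(
        self,
        process_batch: Callable[[list, int], list],
    ) -> None:
        self._process_batch = process_batch

    def _process_batch_data(self, batch_data: list, batch_index: int) -> list:
        return self._process_batch(batch_data, batch_index)

# the (I, U, R)-dependent splitting logic now lives in a free function
def split_tuples(batch_data: list, batch_index: int) -> list:
    return [tuple(item) for item in batch_data]

ds = DynamicDataset(split_tuples)
print(ds._process_batch_data([(1, 2), (3, 4)], 0))  # [(1, 2), (3, 4)]
```
The processing burden is identical; it has just moved outside the class walls.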
Holding:
@abstractmethod
def _get_uri_groups(self) -> Iterable[tuple[U, ...]]:
"""
Get URI groups for each batch.
If there's more than one URI per batch (e.g., a data file and a
metadata file), zip the URIs such that we have a tuple of URIs per
batch.
Note that this effectively defines the index style over batches in the
attached domain. We get an ``int -> tuple[U, ...]`` map that turns
batch indices into URIs that can be read under the domain.
``get_batch()`` turns an integer index into its corresponding
``tuple[U, ...]``, reading the resources in the tuple with
``_read_resources()`` and treating them as providers of batched data.
``_read_resources()`` passes through to the attached domain logic,
which, although common, need not supply an explicit iterable of batch
items: we just access items with ``__getitem__()`` and may ask for
``__len__``. So the returned URI group collection (this method) does
need to be iterable to measure the number of batches, but the batch
objects that are ultimately produced by these URI groups need not be
iterables themselves.
"""
raise NotImplementedError
def _read_resources(
self,
uri_group: tuple[U, ...],
batch_index: int,
) -> tuple[R, ...]:
"""
Read batch files at the provided paths.
This method should operate on a single tuple from the list of batch
tuples returned by the ``_get_uri_groups()`` method. That is, it reads
all of the resources for a single batch and returns a tuple of the same
size with their contents.
Note: the dependence on a batch index is mostly here to make
multi-dataset composition easier later. In-dataset, you don't need to
know the batch index to simply process URIs, but across datasets you
need it to find out the origin of the batch (and process those URIs
accordingly).
"""
return tuple(self.domain.read(uri) for uri in uri_group)
# pulling the type variable out of the inline generic b/c `ty` has trouble
# understanding bound type variables in subclasses (specifically with Self@)
T = TypeVar("T", bound=NamedTuple)
class NamedTupleDataset[I](Dataset):
def __init__(self, data_list: list[I]) -> None:
self.data_list = data_list
def __len__(self) -> int:
return len(self.data_list)
def __getitem__(self, index: int) -> I:
return self.data_list[index]
"""
import math
import random
import logging
from abc import abstractmethod
from copy import copy
from bisect import bisect
from typing import Unpack, TypedDict
from functools import lru_cache
from collections import defaultdict
from collections.abc import Callable, Iterator
from concurrent.futures import ThreadPoolExecutor
from torch.utils.data import Dataset
from trainlib.utils import job
from trainlib.domain import Domain, SequenceDomain
from trainlib.transform import Transform
logger: logging.Logger = logging.getLogger(__name__)
class DatasetKwargs[I](TypedDict, total=False):
pre_transform: Transform[I]
post_transform: Transform[I]
batch_cache_limit: int
preload: bool
num_workers: int
class BatchedDataset[U, R, I](Dataset):
"""
Generic dataset that dynamically pulls batched data from resources (e.g.,
files, remote locations, streams, etc).
The class is generic over a URI type ``U``, a resource type ``R`` (both of
which are used to concretize a domain ``Domain[U, R]``), and an item type
``I``.
Pipeline overview:
```
Domain -> [U] (get _batch_uris)
U -> R (domain access ; Rs provide batches)
R -> [I] (cache here ; _process_batch_data applies pre_transform)
[I] -> I (item retrieval ; _get_item)
I -> I (final item ; __getitem__ applies post_transform)
```
Note^1: as far as positioning, this class is meant to play nice with
PyTorch DataLoaders, hence the inheritance from ``torch.Dataset``. The
value added over the ``torch.Dataset`` base lies almost entirely in the
logic that maps *batched resources* holding data into flat, typical
dataset items. There are also some QoL additions for splitting and
balancing samples.
Note^2: even though ``Domains`` implement iterators over their URIs, this
doesn't imply a ``BatchedDataset`` is iterable. This just means we can walk
over the resources that provide data, but we don't necessarily presuppose
an ordered walk over samples within batches. Point being:
``torch.Dataset``, not ``torch.IterableDataset``, is the appropriate
superclass, even when we're working around iterable ``Domains``.
Note^3: transforms are expected to operate on ``I``-items and produce
``I``-items. They shouldn't be the "introducers" of ``I`` types from some
other intermediate representation, nor should they map from ``I`` to
something else. Point being: the dataset definition should be able to map
resources ``R`` to ``I`` without a transform: that much should be baked
into the class definition. If you find you're expecting the transform to do
that for you, you should consider pulling in some common structure across
the allowed transforms and make it a fixed part of the class.
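For instance, a well-behaved transform under this contract (assuming a
dict-shaped item, purely for illustration):
```python
# maps an item to an item of the same shape: I -> I
def normalize(item: dict) -> dict:
    total = sum(item["values"]) or 1.0
    return {**item, "values": [v / total for v in item["values"]]}

print(normalize({"values": [1.0, 3.0]})["values"])  # [0.25, 0.75]
```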
"""
def __init__(
self,
domain: Domain[U, R],
pre_transform: Transform[I] | None = None,
post_transform: Transform[I] | None = None,
batch_cache_limit: int | None = None,
preload: bool = False,
num_workers: int = 1,
) -> None:
"""
Parameters:
pre_transform: transform to apply over items during loading (in
``_process_batch_data()``), i.e., *before* going into
persistent storage
post_transform: transform to apply just prior to returning an item
(in ``_process_item_data()``), i.e., only *after* retrieval
from persistent storage
batch_cache_limit: the max number of batches to cache at any
one time
preload: whether to load all data into memory during instantiation
"""
self.domain = domain
self.pre_transform = pre_transform
self.post_transform = post_transform
self.batch_cache_limit = batch_cache_limit
self.num_workers = num_workers
logger.info("Fetching URIs...")
self._batch_uris: list[U] = list(domain)
self._indices: list[int] | None = None
self._dataset_len: int | None = None
self._num_batches: int = len(domain)
self.get_batch: Callable[[int], list[I]] = lru_cache(
maxsize=batch_cache_limit
)(self._get_batch)
if preload:
self.load_all(num_workers=num_workers)
@abstractmethod
def _get_dataset_len(self) -> int:
"""
Calculate the total dataset size in units of samples (not batches).
"""
raise NotImplementedError
@abstractmethod
def _get_batch_for_item(self, item_index: int) -> tuple[int, int]:
"""
Return the index of the batch containing the item at the provided item
index, and the index of the item within that batch.
The behavior of this method can vary depending on what we know about
batch sizes, and should therefore be implemented by inheriting classes.
Returns:
batch_index: int
index_in_batch: int
"""
raise NotImplementedError
@abstractmethod
def _process_batch_data(
self,
batch_data: R,
batch_index: int,
) -> list[I]:
"""
Process raw domain resource data (e.g., parse JSON or load tensors) and
split accordingly, such that the returned batch is a collection of
``I`` items.
If an inheriting class wants to allow dynamic transforms, this is the
place to use a provided ``pre_transform``; the collection of items
produced by this method are cached as a batch, so results from such a
transform will be stored.
Parameters:
batch_data: tuple of resource data
batch_index: index of batch
"""
raise NotImplementedError
@abstractmethod
def _process_item_data(
self,
item_data: I,
item_index: int,
) -> I:
"""
Process individual items and produce final tuples.
If an inheriting class wants to allow dynamic transforms, this is the
place to use a provided ``post_transform``; items are pulled from the
cache (if enabled) and processed before being returned as the final
tuple outputs (so this processing is not persistent).
"""
raise NotImplementedError
def _get_item(self, item_index: int) -> I:
"""
Get the item data and zip with the item header.
Items should be the most granular representation of dataset samples
with maximum detail (i.e., yield all available information). An
iterator over these representations across all samples can be retrieved
with `.items()`.
Note that return values from `__getitem__()` are "cleaned up" versions
of this representation, with minimal info needed for training.
"""
if item_index >= len(self):
raise IndexError
# alt indices redefine index count
item_index = self.indices[item_index]
batch_index, index_in_batch = self._get_batch_for_item(item_index)
return self.get_batch(batch_index)[index_in_batch]
def _get_batch(self, batch_index: int) -> list[I]:
"""
Return the batch data for the provided index.
Note that we require a list return type. This is where the rubber meets
the road in terms of expanding batches: the outputs here get cached, if
caching is enabled. If we were to defer batch expansion to some later
stage, that caching will be more or less worthless. For instance, if
``._process_batch_data()`` was an iterator, at best the batch
processing logic could be delayed, but then you'd either 1) cache the
iterator reference, or 2) have to further delay caching until
post-batch processing. Further, ``._process_batch_data()`` can (and
often does) depend on the entire batch, so you can't handle that
item-wise: you've got to pass all batch data in and can't act on
slices, so doing this would be irrelevant anyway.
How about iterators out from ``_read_resources()``? Then your
``_process_batch_data()`` can iterate as needed and do the work later?
The point here is that this distinction is irrelevant because you're
reading resources and processing the data here in the same method:
they're always connected, and nothing would notice if you waited
between steps. The only way this could matter is if you split the
resource reading and batch processing steps across methods, but when it
actually comes to accessing/caching the batch, you'd have to expand
any delayed reads here. There's no way around needing to see all batch
data at once here, and we don't want to make that ambiguous: ``list``
output type it is.
"""
logger.debug("Batch cache miss, reading from root...")
if batch_index >= self._num_batches:
raise IndexError
batch_uri = self._batch_uris[batch_index]
batch_data = self.domain[batch_uri]
return self._process_batch_data(batch_data, batch_index)
def load_all(self, num_workers: int | None = None) -> list[list[I]]:
"""
Preload all data batches into the cache.
Can be useful when dynamically pulling data (as it's requested) isn't
desired. Requires that `batch_cache_limit=None`, i.e., that the cache
won't continually evict previous batches as new ones are loaded.
"""
assert self.batch_cache_limit is None, "Preloading under cache limit"
if num_workers is None:
num_workers = self.num_workers
thread_pool = ThreadPoolExecutor(max_workers=num_workers)
futures = []
for batch_index in range(self._num_batches):
future = thread_pool.submit(
self.get_batch,
batch_index,
)
futures.append(future)
job.process_futures(futures, "Loading dataset batches", "batch")
thread_pool.shutdown(wait=True)
return [future.result() for future in futures]
def split(
self,
fracs: list[float],
dataset: "BatchedDataset | None" = None,
by_attr: str | list[str | None] | None = None,
shuffle_strata: bool = True,
) -> list["BatchedDataset"]:
"""
Split dataset into fractional pieces by data attribute.
If `by_attr` is None, recovers typical fractional splitting of dataset
items, partitioning by size. Using None anywhere will index each item
into its own bucket, i.e., by its index. For instance,
- `by_attr=["color"]` -> {("red", 1), ("red", 2)},
{("blue", 1), ("blue", 2)}
Splits on the attribute such that each subset contains entire strata
of the attribute. "Homogeneity within clusters"
- `by_attr=["color", None]` -> {("red", 1), ("blue", 1)},
{("red", 2), ("blue", 2)}
Stratifies by attribute and then splits "by index" within, uniformly
grabbing samples across strata to form new clusters. "Homogeneity
across clusters"
Note that the final list of Subsets returned are built from shallow
copies of the underlying dataset (i.e., `self`) to allow manual
intervention with dataset attributes (e.g., setting the splits to have
different `transform`s). This is subject to possibly unexpected
behavior if re-caching data or you need a true copy of all data in
memory, but should otherwise leave most interactions unchanged.
Parameters:
shuffle_strata: shuffle the strata order before split is drawn. We
parameterize this because a dataloader-level shuffle operation
will only change the order of the indices in the resulting
splits; only a shuffle of items inside the strata can change
the actual content of the splits themselves.
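The grouping step can be illustrated standalone (a sketch of the
internals, not the actual implementation):
```python
from collections import defaultdict

items = [
    {"color": "red"}, {"color": "red"},
    {"color": "blue"}, {"color": "blue"},
]

# group item indices by attribute, as the split body does
attr_dict: defaultdict[str, list[int]] = defaultdict(list)
for i, item in enumerate(items):
    attr_dict[item["color"]].append(i)

# a by_attr=["color"] split then draws whole strata together
strata = list(attr_dict.values())
print(strata)  # [[0, 1], [2, 3]]
```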
"""
if by_attr == []:
raise ValueError("Cannot parse empty value list")
assert (
math.isclose(sum(fracs), 1) and sum(fracs) <= 1
), "Fractions do not sum to 1"
if isinstance(by_attr, str) or by_attr is None:
by_attr = [by_attr]
if dataset is None:
dataset = self
# dataset = DictDataset([
# self._get_item(i) for i in range(len(self))
# ])
# group samples by specified attr
attr_dict = defaultdict(list)
attr_key, by_attr = by_attr[0], by_attr[1:]
# for i in range(len(dataset)):
for i, item in enumerate(dataset.items()):
# item = dataset[i]
if attr_key is None:
attr_val = i
elif attr_key in item:
attr_val = item[attr_key]
else:
raise IndexError(f"Attribute {attr_key} not in dataset item")
attr_dict[attr_val].append(i)
if by_attr == []:
attr_keys = list(attr_dict.keys())
# shuffle keys; randomized group-level split
if shuffle_strata:
random.shuffle(attr_keys)
# considering: defer to dataloader shuffle param; should have same
# effect shuffle values; has no impact on where the split is drawn
# for attr_vals in attr_dict.values():
# random.shuffle(attr_vals)
# fractionally split over attribute keys
offset, splits = 0, []
for frac in fracs[:-1]:
frac_indices = []
frac_size = int(frac * len(attr_keys))
for j in range(offset, offset + frac_size):
j_indices = attr_dict.pop(attr_keys[j])
frac_indices.extend(j_indices)
offset += frac_size
splits.append(frac_indices)
rem_indices = []
for r_indices in attr_dict.values():
rem_indices.extend(r_indices)
splits.append(rem_indices)
subsets = []
for split in splits:
subset = copy(dataset)
subset.indices = split
subsets.append(subset)
return subsets
else:
splits = [[] for _ in range(len(fracs))]
for index_split in attr_dict.values():
# subset = Subset(dataset, index_split)
subset = copy(dataset)
subset.indices = index_split
subset_splits = self.split(
fracs, subset, by_attr, shuffle_strata
)
# unpack stratified splits
for i, subset_split in enumerate(subset_splits):
# splits[i].extend([
# index_split[s] for s in subset_split.indices
# ])
splits[i].extend(subset_split.indices)
subsets = []
for split in splits:
subset = copy(dataset)
subset.reset_indices()
subset.indices = split
subsets.append(subset)
return subsets
# considering: defer to dataloader shuffle param; should have same
# effect shuffle each split after merging; may otherwise be homogenous
# for split in splits:
# random.shuffle(split)
# return [Subset(copy(self), split) for split in splits]
def balance(
self,
dataset: "BatchedSubset[U, R, I] | None" = None,
by_attr: str | list[str | None] | None = None,
split_min_sizes: list[int] | None = None,
split_max_sizes: list[int] | None = None,
shuffle_strata: bool = True,
) -> None:
self.indices = self._balance(
dataset,
by_attr,
split_min_sizes,
split_max_sizes,
shuffle_strata,
)
def _balance(
self,
dataset: "BatchedSubset[U, R, I] | None" = None,
by_attr: str | list[str | None] | None = None,
split_min_sizes: list[int] | None = None,
split_max_sizes: list[int] | None = None,
shuffle_strata: bool = True,
) -> list[int]:
"""
Note: behavior is a little odd for nested balancing; the result is not
perfectly uniform throughout. This is hard to avoid: you can't know
ahead of time the size of the subgroups across splits.
"""
if by_attr == []:
raise ValueError("Cannot parse empty value list")
if isinstance(by_attr, str) or by_attr is None:
by_attr = [by_attr]
if dataset is None:
dataset = BatchedSubset(self, self.indices)
if split_min_sizes == [] or split_min_sizes is None:
split_min_sizes = [0]
if split_max_sizes == [] or split_max_sizes is None:
split_max_sizes = [len(dataset)]
# group samples by specified attr
attr_dict = defaultdict(list)
attr_key, by_attr = by_attr[0], by_attr[1:]
for i in range(len(dataset)):
item = dataset[i]
if attr_key is None:
attr_val = i
elif attr_key in item:
attr_val = item[attr_key]
else:
raise IndexError(f"Attribute {attr_key} not in dataset item")
attr_dict[attr_val].append(i)
subset_splits = []
split_min_size, split_min_sizes = (
split_min_sizes[0],
split_min_sizes[1:]
)
split_max_size, split_max_sizes = (
split_max_sizes[0],
split_max_sizes[1:]
)
for split_indices in attr_dict.values():
if by_attr != []:
subset_indices = self._balance(
BatchedSubset(dataset, split_indices),
by_attr,
split_min_sizes,
split_max_sizes,
shuffle_strata,
)
split_indices = [split_indices[s] for s in subset_indices]
subset_splits.append(split_indices)
# shuffle splits; randomized group-level split
if shuffle_strata:
random.shuffle(subset_splits)
# note: split_min_size is smallest allowed, min_split_size is smallest
# observed
valid_splits = [
ss for ss in subset_splits
if len(ss) >= split_min_size
]
min_split_size = min(len(split) for split in valid_splits)
subset_indices = []
for split_indices in valid_splits:
# if shuffle_strata:
# random.shuffle(split_indices)
subset_indices.extend(
split_indices[: min(min_split_size, split_max_size)]
)
# print(f"{attr_dict.keys()=}")
# print(f"{[len(s) for s in valid_splits]=}")
# print(f"{min_split_size=}")
return subset_indices
@property
def indices(self) -> list[int]:
if self._indices is None:
self._indices = list(range(len(self)))
return self._indices
@indices.setter
def indices(self, indices: list[int]) -> None:
"""
Note: this logic facilitates nested re-indexing over the same base
dataset. The underlying data remain the same, but when indices get set,
you're effectively applying a mask over any existing indices, always
operating *relative* to the existing mask.
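For example (a sketch of the composition):
```python
# each assignment indexes into the current mask, composing masks
indices = list(range(10))                    # base mask: 0..9
indices = [indices[i] for i in [2, 4, 6]]    # mask -> [2, 4, 6]
indices = [indices[i] for i in [0, 2]]       # relative mask -> [2, 6]
print(indices)  # [2, 6]
```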
"""
# manually set new size
self._dataset_len = len(indices)
# note: this is a little tricky and compact; follow what happens when
# _indices aren't already set
self._indices = [
self.indices[index]
for index in indices
]
def reset_indices(self) -> None:
self._indices = None
self._dataset_len = None
def items(self) -> Iterator[I]:
for i in range(len(self)):
yield self._get_item(i)
def __len__(self) -> int:
if self._dataset_len is None:
self._dataset_len = self._get_dataset_len()
return self._dataset_len
def __getitem__(self, index: int) -> I:
item_data = self._get_item(index)
index = self.indices[index]
return self._process_item_data(item_data, index)
def __iter__(self) -> Iterator[I]:
"""
Note: this method isn't technically needed given ``__getitem__`` is
defined and we operate cleanly over integer indices 0..(N-1), so even
without an explicit ``__iter__``, Python will fall back to a reliable
iteration mechanism. We nevertheless implement the trivial logic below
to convey intent and meet static type checks for iterables.
"""
return (self[i] for i in range(len(self)))
class CompositeBatchedDataset[U, R, I](BatchedDataset[U, R, I]):
"""
Dataset class for wrapping individual datasets.
Note: because this remains a valid ``BatchedDataset``, we re-thread the
generic type variables through the set of composed datasets. That is, they
must have a common domain type ``Domain[U, R]``.
"""
def __init__(
self,
datasets: list[BatchedDataset[U, R, I]],
) -> None:
"""
Parameters:
datasets: list of datasets
"""
self.datasets = datasets
self._indices: list[int] | None = None
self._dataset_len: int | None = None
self._item_psum: list[int] = []
self._batch_psum: list[int] = []
def _compute_prefix_sum(self, arr: list[int]) -> list[int]:
if not arr:
return []
prefix_sum = [0] * (len(arr) + 1)
for i in range(len(arr)):
prefix_sum[i + 1] = prefix_sum[i] + arr[i]
return prefix_sum
def _get_dataset_for_item(self, item_index: int) -> tuple[int, int]:
dataset_index = bisect(self._item_psum, item_index) - 1
index_in_dataset = item_index - self._item_psum[dataset_index]
return dataset_index, index_in_dataset
def _get_dataset_for_batch(self, batch_index: int) -> tuple[int, int]:
dataset_index = bisect(self._batch_psum, batch_index) - 1
index_in_dataset = batch_index - self._batch_psum[dataset_index]
return dataset_index, index_in_dataset
def _get_batch_for_item(self, item_index: int) -> tuple[int, int]:
index_pair = self._get_dataset_for_item(item_index)
dataset_index, index_in_dataset = index_pair
dataset = self.datasets[dataset_index]
return dataset._get_batch_for_item(index_in_dataset)
def _get_dataset_len(self) -> int:
self.load_all()
dataset_sizes = [len(dataset) for dataset in self.datasets]
dataset_batch_counts = [dataset._num_batches for dataset in self.datasets]
# this method will only be run once; set these instance vars
self._item_psum = self._compute_prefix_sum(dataset_sizes)
self._batch_psum = self._compute_prefix_sum(dataset_batch_counts)
return self._item_psum[-1]
def _process_batch_data(
self,
batch_data: R,
batch_index: int,
) -> list[I]:
index_pair = self._get_dataset_for_batch(batch_index)
dataset_index, index_in_dataset = index_pair
dataset = self.datasets[dataset_index]
return dataset._process_batch_data(batch_data, index_in_dataset)
def _process_item_data(self, item_data: I, item_index: int) -> I:
index_pair = self._get_dataset_for_item(item_index)
dataset_index, index_in_dataset = index_pair
dataset = self.datasets[dataset_index]
return dataset._process_item_data(item_data, index_in_dataset)
def _get_item(self, item_index: int) -> I:
item_index = self.indices[item_index]
index_pair = self._get_dataset_for_item(item_index)
dataset_index, index_in_dataset = index_pair
dataset = self.datasets[dataset_index]
return dataset._get_item(index_in_dataset)
def _get_batch(self, batch_index: int) -> list[I]:
index_pair = self._get_dataset_for_batch(batch_index)
dataset_index, index_in_dataset = index_pair
dataset = self.datasets[dataset_index]
return dataset._get_batch(index_in_dataset)
def load_all(
self,
num_workers: int | None = None
) -> list[list[I]]:
batches = []
for dataset in self.datasets:
batches.extend(dataset.load_all(num_workers))
return batches
@property
def pre_transform(self) -> list[Transform[I] | None]:
return [dataset.pre_transform for dataset in self.datasets]
@pre_transform.setter
def pre_transform(self, pre_transform: Transform[I]) -> None:
for dataset in self.datasets:
dataset.pre_transform = pre_transform
@property
def post_transform(self) -> list[Transform[I] | None]:
return [dataset.post_transform for dataset in self.datasets]
@post_transform.setter
def post_transform(self, post_transform: Transform[I]) -> None:
for dataset in self.datasets:
dataset.post_transform = post_transform
class BatchedSubset[U, R, I](BatchedDataset[U, R, I]):
def __init__(
self,
dataset: BatchedDataset[U, R, I],
indices: list[int],
) -> None:
self.dataset = dataset
# bypass the relative-masking setter: subset indices are absolute
# into the wrapped dataset
self._indices: list[int] | None = indices
self._dataset_len: int | None = len(indices)
def _get_item(self, item_index: int) -> I:
"""
Subset indices are "reset" in its context. Simply passes through
"""
return self.dataset._get_item(self.indices[item_index])
def _get_dataset_len(self) -> int:
return len(self.indices)
class HomogenousDataset[U, R, I](BatchedDataset[U, R, I]):
"""
Batched dataset where batches are equally sized.
Subclass from this base when you can count on the reference data being
prepared with fixed size batches (up to the last batch), e.g., that which
has been prepared with a `Packer`. This can greatly improve measurement
time of the dataset size by preventing the need for reading all batch files
upfront, and reduces the cost of identifying item batches from O(log n) to
O(1).
Methods left for inheriting classes:
- ``_process_item_data()``: item processing
- ``_process_batch_data()``: batch processing
"""
def __init__(
self,
domain: Domain[U, R],
**kwargs: Unpack[DatasetKwargs],
) -> None:
super().__init__(domain, **kwargs)
# determine batch size across dataset, along w/ possible partial final
# batch
bsize = rem = len(self.get_batch(self._num_batches - 1))
if self._num_batches > 1:
bsize, rem = len(self.get_batch(self._num_batches - 2)), bsize
self._batch_size: int = bsize
self._batch_rem: int = rem
def _get_dataset_len(self) -> int:
return self._batch_size * (self._num_batches - 1) + self._batch_rem
def _get_batch_for_item(self, item_index: int) -> tuple[int, int]:
return item_index // self._batch_size, item_index % self._batch_size
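The O(1) lookup above is plain integer division and modulo on a fixed batch size; a standalone sketch with made-up counts (3 batches of size 4, partial final batch of 2):

```python
# Hypothetical numbers: 3 batches of size 4, with a partial final batch of 2.
batch_size, num_batches, batch_rem = 4, 3, 2

def get_batch_for_item(item_index: int) -> tuple[int, int]:
    # mirrors HomogenousDataset._get_batch_for_item: div for the batch index,
    # mod for the position within that batch
    return divmod(item_index, batch_size)

dataset_len = batch_size * (num_batches - 1) + batch_rem
assert dataset_len == 10
assert get_batch_for_item(0) == (0, 0)
assert get_batch_for_item(9) == (2, 1)  # lands in the partial final batch
```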
class HeterogenousDataset[U, R, I](BatchedDataset[U, R, I]):
"""
Batched dataset where batches have arbitrary size.
Methods left for inheriting classes:
- ``_process_item_data()``: item processing
- ``_process_batch_data()``: batch processing
"""
def __init__(
self,
domain: Domain[U, R],
**kwargs: Unpack[DatasetKwargs],
) -> None:
super().__init__(domain, **kwargs)
self._batch_size_psum: list[int] = []
def _compute_prefix_sum(self, arr: list[int]) -> list[int]:
if not arr:
return []
prefix_sum = [0] * (len(arr) + 1)
for i in range(len(arr)):
prefix_sum[i + 1] = prefix_sum[i] + arr[i]
return prefix_sum
def _get_dataset_len(self) -> int:
# type error below: no idea why this is flagged
batches = self.load_all()
batch_sizes = [len(batch) for batch in batches]
        # this method will only be run once; set this instance var
self._batch_size_psum = self._compute_prefix_sum(batch_sizes)
return self._batch_size_psum[-1]
def _get_batch_for_item(self, item_index: int) -> tuple[int, int]:
batch_index = bisect(self._batch_size_psum, item_index) - 1
index_in_batch = item_index - self._batch_size_psum[batch_index]
return batch_index, index_in_batch
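For heterogeneous batches, the prefix-sum plus `bisect` combination above maps a global item index to a (batch, offset) pair in O(log n); a self-contained sketch with made-up batch sizes:

```python
from bisect import bisect

batch_sizes = [3, 1, 4]              # hypothetical per-batch sizes
psum = [0]
for size in batch_sizes:
    psum.append(psum[-1] + size)     # psum == [0, 3, 4, 8]

def get_batch_for_item(item_index: int) -> tuple[int, int]:
    # bisect finds the rightmost prefix boundary <= item_index
    batch_index = bisect(psum, item_index) - 1
    return batch_index, item_index - psum[batch_index]

assert psum[-1] == 8                  # total dataset length
assert get_batch_for_item(2) == (0, 2)  # last item of batch 0
assert get_batch_for_item(3) == (1, 0)  # sole item of batch 1
assert get_batch_for_item(7) == (2, 3)  # last item of batch 2
```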
class SequenceDataset[I](HomogenousDataset[int, I, I]):
"""
Trivial dataset skeleton for sequence domains.
``I``-typed sequence items map directly to dataset items. To produce a
fully concrete dataset, one still needs to define ``_process_item_data()``
to map from ``I``-items to tuples.
"""
domain: SequenceDomain[I]
def _process_batch_data(
self,
batch_data: I,
batch_index: int,
) -> list[I]:
if self.pre_transform is not None:
batch_data = self.pre_transform(batch_data)
return [batch_data]
class TupleDataset[T](SequenceDataset[tuple[T, ...]]):
"""
Trivial sequence-of-tuples dataset.
This is the most straightforward line to a concrete dataset from a
``BatchedDataset`` base class. That is: the underlying domain is a sequence
whose items are mapped to single-item batches and are already tuples.
"""
def _process_item_data(
self,
item_data: tuple[T, ...],
item_index: int,
) -> tuple[T, ...]:
if self.post_transform is not None:
item_data = self.post_transform(item_data)
return item_data

trainlib/datasets/disk.py
from io import BytesIO
from abc import abstractmethod
from typing import Any, NamedTuple
from pathlib import Path
from zipfile import ZipFile
from trainlib.dataset import HomogenousDataset
from trainlib.domains.disk import DiskDomain
class DiskDataset[T: NamedTuple](HomogenousDataset[Path, bytes, T]):
"""
The following line is to satisfy the type checker, which
1. Can't recognize an appropriately re-typed constructor arg like
def __init__(
self,
domain: DiskDomain,
...
): ...
This *does* match the parent generic for the U=Path, R=bytes context
def __init__(
self,
domain: Domain[U, R],
...
): ...
but the type checker doesn't see this.
2. "Lifted" type variables out of generics can't be used as upper bounds,
at least not without throwing type checker warnings (thanks to PEP695).
So I'm not allowed to have
```
class BatchedDataset[U, R, D: Domain[U, R]]:
...
```
which could bring appropriately dynamic typing for ``Domain``s, but is
not a sufficiently concrete upper bound.
    So we settle for a class-level type declaration which, despite not being
    technically scoped as intended, harms nothing and satisfies ``ty`` type
    checks downstream (e.g., when we access ``DiskDomain.root``).
"""
domain: DiskDomain
class PackedDataset(DiskDataset):
"""
Packed dataset.
    Currently out of commission: not compatible with the latest dataset
    definitions. Will require a zipped disk domain.
Requires a specific dataset storage structure on the root data path:
<data-path>/data/*-i<batch-num>-b<batch-size>
<data-path>/meta/*-i<batch-num>-b<batch-size>
That is, all data are compacted into core data (`data/`) and metadata
(`meta/`) subdirectories. Compatible out-of-the-box with datasets written
with a `Packer`.
"""
def _get_uri_groups(self) -> list[tuple[Path, ...]]:
data_root = Path(self.domain.root, "data")
meta_root = Path(self.domain.root, "meta")
data_file_paths = data_root.iterdir()
meta_file_paths = meta_root.iterdir()
return list(zip(data_file_paths, meta_file_paths, strict=True))
def _process_batch_data(
self,
batch_data: tuple[bytes, ...],
batch_index: int,
) -> list[tuple[bytes, ...]]:
data_bytes, meta_bytes = batch_data
meta_batch = self._unpack_meta(meta_bytes)
data_batch = self._unpack_data(data_bytes, meta_batch)
        # zip up the partial batch items into a single batch iterable
# composed of item tuples
batch_items = [
(*ba, *bm) # pyre-ignore[60]
for ba, bm in zip(data_batch, meta_batch, strict=True)
]
# apply transform to batch items if provided
if self.load_transform:
batch_items = list(map(self.load_transform, batch_items))
return batch_items
@abstractmethod
def _unpack_data(
self,
batch_data_bytes: bytes,
batch_meta: list[tuple[Any, ...]],
) -> list[tuple[Any, ...]]:
"""
Load and unpack batch data.
This method should be the inverse of an affiliated
`Packer`'s `pack_data_bytes()`:
<data list> -> pack -> to bytes -> blob
<data list> <- unpack <- from bytes <- blob
Returns:
iterable of (partial) item tuples w/ data content
"""
raise NotImplementedError
@abstractmethod
def _unpack_meta(self, batch_meta_bytes: bytes) -> list[tuple[Any, ...]]:
"""
Load and unpack batch metadata.
This method should be the inverse of an affiliated
`Packer`'s `pack_meta_bytes()`:
<data list> -> pack -> to bytes -> blob
<data list> <- unpack <- from bytes <- blob
Returns:
iterable of (partial) item tuples w/ meta content
"""
raise NotImplementedError
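The pack/unpack inverse sketched in the docstrings above can be illustrated with pickle as a stand-in serialization; the actual `Packer` wire format isn't specified here, so these helper names are hypothetical:

```python
import pickle

def pack_meta_bytes(meta: list[tuple]) -> bytes:
    # <data list> -> pack -> to bytes -> blob
    return pickle.dumps(meta)

def unpack_meta(blob: bytes) -> list[tuple]:
    # <data list> <- unpack <- from bytes <- blob
    return pickle.loads(blob)

meta = [("a.png", 123), ("b.png", 456)]
assert unpack_meta(pack_meta_bytes(meta)) == meta  # round-trip identity
```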
class ZippedDataset(DiskDataset):
"""
Dataset with samples stored in ZIP files.
This dataset base is primarily used as the type of input dataset for a
`Packer` object. This is compatible with most raw dataset structures,
reading down arbitrarily packaged ZIP files and "re-batching" during
access.
"""
item_header: tuple[str, ...] = ("bytes",)
def _get_uri_groups(self) -> list[tuple[str, ...]]:
zip_file_paths: list[str] = [
str(path)
            for path in Path(self.domain.root).iterdir()
if path.suffix == ".zip"
]
# will just zip a single list yielding 1-tuples (to match type sig)
return list(zip(zip_file_paths))
def _process_batch_data(
self,
batch_data: tuple[bytes, ...],
batch_index: int,
) -> list[tuple[bytes, ...]]:
items = []
batch_zip_file = ZipFile(BytesIO(batch_data[0]))
for zname in batch_zip_file.namelist():
            exts = self.domain.extensions
            if exts is not None and Path(zname).suffix not in exts:
continue
with batch_zip_file.open(zname, "r") as zfile:
items.append((zfile.read(), zname))
return list(items)
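The in-memory ZIP re-batching pattern above (raw bytes into `BytesIO`, members expanded into item tuples) can be exercised standalone; the file names here are made up:

```python
from io import BytesIO
from zipfile import ZipFile

# build a small ZIP blob in memory (stand-in for one batch's raw bytes)
buf = BytesIO()
with ZipFile(buf, "w") as zf:
    zf.writestr("x.txt", b"one")
    zf.writestr("y.txt", b"two")

# re-open from raw bytes and expand members into (bytes, name) item tuples
batch_zip_file = ZipFile(BytesIO(buf.getvalue()))
items = [
    (batch_zip_file.read(zname), zname)
    for zname in batch_zip_file.namelist()
]
assert items == [(b"one", "x.txt"), (b"two", "y.txt")]
```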

trainlib/datasets/memory.py
"""
Leaving the following here in case we return for some specifics in the future.
At an earlier stage, we had a few specifically typed datasets/domains for
in-memory structures, but these are almost entirely redundant with the generic
``SequenceDomain`` definition: the general retrieval and iteration behaviors
there are fairly universal and can be type-tailored without new class
definitions (unless you really want a new hierarchy).
The following were design notes on a dict/named tuple list-based dataset.
Dataset from list of (dict) records.
This is designed such that a "batch" is just a single record in the base
list. Each URI group is a singleton index tuple, which grabs its single
corresponding record in the domain when read.
One could alternatively have the entire record list be a single batch, and
have a single URI group with N null references. The null reference would
need to be the sole URI for the domain definition, and ``.read(null)``
would always return the entire list. Everything will then be handled as
expected: in ``get_item()``,
- ``._get_batch_for_item(item_index)`` maps to the singleton batch
- ``.get_batch(batch_index)[index_in_batch]`` simply indexes directly in
the record list
This is pretty unnatural, since now a batch as returned by
``.read_resources()`` will be N references to the entire record list. It
nevertheless sidesteps full list reconstruction and allows the propagation
of direct indexing, so it has that in its corner.
Another viable approach (least preferred): have a single batch, but where
URIs map to individual row indices. In some respects, this is the most
intuitive interpretation of "batch" and how we'd map to items, but it's the
*least efficient* because the batch logic of ``BatchedDataset`` will
*reconstruct the entire record list*. In ``.read_resources()``, we have
```
return tuple(self.domain.read(uri) for uri in uri_group)
```
So, despite being a natural interpretation, it pulls apart the domain "too
well" and rebuilds it in-house. This makes sense when resources are
external (e.g., files on disk, data over network), but when already
part of the Python process, that model is wasteful.
Note: inheriting datasets will need to implement an appropriate
``_item_header``. This will be trivial for consistent datasets, since it'll
just be the keys of any of the dict records.
Left to define:
+ ``_process_item_data()``: map T to final tuple
class RecordDataset[T: NamedTuple](HomogenousDataset[int, T, T]):
domain: TupleListDomain[T]
def _process_batch_data(
self,
batch_data: T,
batch_index: int,
) -> list[T]:
Produce collection of item tuples.
Note: no interaction with ``item_tuple`` is needed given batches are
already singular ``item_tuple`` shaped items.
return [batch_data]
"""
from typing import Unpack
import torch.nn.functional as F
from torch import Tensor
from trainlib.domain import SequenceDomain
from trainlib.dataset import TupleDataset, DatasetKwargs
class SlidingWindowDataset[T: Tensor](TupleDataset[T]):
def __init__(
self,
domain: SequenceDomain[tuple[T, ...]],
lookback: int,
offset: int = 0,
lookahead: int = 1,
num_windows: int = 1,
**kwargs: Unpack[DatasetKwargs],
) -> None:
self.lookback = lookback
self.offset = offset
self.lookahead = lookahead
self.num_windows = num_windows
super().__init__(domain, **kwargs)
def _process_batch_data(
self,
batch_data: tuple[T, ...],
batch_index: int,
) -> list[tuple[T, ...]]:
"""
Backward pads first sequence over (lookback-1) length, and steps the
remaining items forward by the lookahead.
Batch data:
(Tensor[C1, T], ..., Tensor[CN, T])
+ lookback determines window size; pad left to create lookback size
with the first element at the right:
|-lookback-|
[0 ... 0 T0]
`lookback` is strictly positive, unbounded.
+ offset shifts the first such window forward. `0 <= offset < L`;
think of it as the number of additional non-padded items we
slide into the window from the right. At its largest value of `L-1`,
we'll have L-sized windows with `T0` as the leftmost element.
offset=2
[0 ... T0 T1 T2]
|---lookback---|
In effect, the index of the rightmost item of the first window
will be equal to the value of `offset`. There are `T - offset`
total possible windows over a given sequence (regardless of
lookback).
+ lookahead determines the offset of the "label" slices from the
first index, regardless of any value of `offset`.
lookahead=3
[0 ... 0 T0] T1 T2 [T3]
|-lookback-|
0 [0 .. T0 T1] T2 T3 [T4]
|-lookback-|
There are `T - lookahead` allowed slices, assuming the lookahead
exceeds the offset.
To get windows starting with the first index at the left: we first set
        our window size (call it L), determined by `lookback`. Then the
rightmost index we want will be `L-1`, which determines our `offset`
setting.
                        lookback=L, offset=L-1
[ T_0 ... T_{L-1} ]
To get a one-step lookahead in front of that rightmost item, the
`lookahead` can be set to the index of the first label we want:
                lookback=L, offset=L-1, lookahead=L
[ T_0 ... T_{L-1} ] [ T_L ]
"""
if self.pre_transform is not None:
batch_data = self.pre_transform(batch_data)
lb = self.lookback
off = min(self.offset, lb-1)
la = self.lookahead
ws = []
for t in batch_data[:self.num_windows]:
# for window sized `lb`, we pad with `lb-1` zeros. We then take off
# the amount of our offset, which in the extreme cases does no
# padding.
xip = F.pad(t, ((lb-1) - off, 0))
# extract sliding windows over the padded tensor
# unfold(-1, lb, 1) slides over the last dim, 1 step at a time, for
# `lb`-sized windows. We turn (C_i, pad+T) shape into
# (C_i, T-offset, lb), giving `T-offset` total `lb` windows.
wi = xip.unfold(-1, lb, 1)
# (C_i, T-offset, lb) -> (T-offset, C_i, lb)
wi = wi.permute(1, 0, 2)
# if lookahead exceeds offset, there are some windows for which we
# won't be able to assign a lookahead label. Cut those off here
if la - off > 0:
wi = wi[:-(la-off)]
ws.append(wi)
ys = []
for t in batch_data[self.num_windows:]:
# tensors (C_i, T) shaped, align with lookahead, giving a
# (C_i, T-lookahead)
y = t[:, la:]
# (C_i, T-lookahead) -> (T-lookahead, C_i)
y = y.permute(1, 0)
# cut off any elements if offset exceeds lookahead
if off - la > 0:
y = y[:-(off-la)]
ys.append(y)
return list(zip(*ws, *ys, strict=True))
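The pad/unfold mechanics described in the docstring can be checked in isolation; a minimal sketch with lookback=4, offset=0, lookahead=1 over a toy (C=1, T=10) tensor:

```python
import torch
import torch.nn.functional as F

lb, off, la = 4, 0, 1                        # lookback, offset, lookahead
t = torch.arange(1.0, 11.0).reshape(1, 10)   # (C=1, T=10): values 1..10

xip = F.pad(t, ((lb - 1) - off, 0))          # left-pad with lb-1-off zeros
wi = xip.unfold(-1, lb, 1)                   # (C, T - off, lb) sliding windows
wi = wi.permute(1, 0, 2)                     # -> (T - off, C, lb)
if la - off > 0:
    wi = wi[:-(la - off)]                    # drop windows without a label

y = t[:, la:].permute(1, 0)                  # labels, (T - la, C)

assert wi.shape == (9, 1, 4) and y.shape == (9, 1)
assert wi[0].tolist() == [[0.0, 0.0, 0.0, 1.0]]  # first window, left-padded
assert y[0].tolist() == [2.0]                    # its one-step-ahead label
```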

trainlib/domain.py
"""
Defines a knowledge domain. Wraps a Dataset / Simulator / Knowledge
Downstream exploration might include
- Calibrating Simulator / Knowledge with a Dataset
- Amending Dataset with Simulator / Knowledge
- Positioning Knowledge within Simulator context
* Where to replace Simulator subsystem with Knowledge?
Other variations:
- Multi-fidelity simulators
- Multi-scale models
- Multi-system
- Incomplete knowledge / divergence among sources
Questions:
- Should Simulator / Knowledge be unified as one (e.g., "Expert")
"""
from collections.abc import Mapping, Iterator, Sequence
class Domain[U, R](Mapping[U, R]):
"""
Domain base class, generic to a URI type ``U`` and resource type ``R``.
Domains are just Mappings where the iterator behavior is specifically typed
to range over keys (URIs). Defining a specific class here gives us a base
for a nominal hierarchy, but functionally the Mapping core (sized iterables
with accessors).
"""
def __call__(self, uri: U) -> R:
"""
Get the resource for a given URI (call-based alias).
"""
return self[uri]
def __getitem__(self, uri: U) -> R:
"""
Get the resource for a given URI.
"""
raise NotImplementedError
def __iter__(self) -> Iterator[U]:
"""
Provide an iterator over domain URIs.
"""
raise NotImplementedError
def __len__(self) -> int:
"""
        Measure the size of the domain.
"""
raise NotImplementedError
class SequenceDomain[R](Domain[int, R]):
"""
Trivial domain implementation for wrapping sequences that can be seen as
Mappings with 0-indexed keys.
Why define this? Domains provide iterators over their *keys*, sequences
often iterate over *values*.
"""
def __init__(self, sequence: Sequence[R]) -> None:
self.sequence = sequence
def __getitem__(self, uri: int) -> R:
return self.sequence[uri]
def __iter__(self) -> Iterator[int]:
return iter(range(len(self.sequence)))
def __len__(self) -> int:
return len(self.sequence)

trainlib/domains/disk.py
from pathlib import Path
from collections.abc import Iterator
from trainlib.domain import Domain
class DiskDomain(Domain[Path, bytes]):
def __init__(
self,
root: Path,
extensions: list[str] | None = None,
) -> None:
"""
Parameters:
extensions: list of file extensions to filter for when determining
data file paths to read. This is a whitelist, except when left
                as ``None`` (the default), in which case all extensions are
                allowed.
"""
self.root = root
self.extensions = extensions
def __getitem__(self, uri: Path) -> bytes:
return uri.read_bytes()
def __iter__(self) -> Iterator[Path]:
return (
path for path in self.root.iterdir()
if path.is_file() and (
self.extensions is None
or path.suffix in self.extensions
)
)
def __len__(self) -> int:
return sum(1 for _ in iter(self))
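The suffix-whitelist iteration above can be exercised standalone over a throwaway directory; the file names here are made up:

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "a.txt").write_bytes(b"alpha")
    (root / "b.bin").write_bytes(b"beta")

    extensions = [".txt"]  # whitelist; None would admit every file
    uris = sorted(
        path for path in root.iterdir()
        if path.is_file() and (
            extensions is None or path.suffix in extensions
        )
    )
    assert [uri.name for uri in uris] == ["a.txt"]
    assert [uri.read_bytes() for uri in uris] == [b"alpha"]
```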

from collections.abc import Callable, Iterator, Sequence
from trainlib.domain import Domain
class SimulatorDomain[P, R](Domain[int, R]):
"""
Base simulator domain, generic to arbitrary callables.
Note: we don't store simulation results here; that's left to a downstream
    object, like a `BatchedDataset`, to cache if needed. We also don't
    subclass `SequenceDomain` because the item getter doesn't align: we
    accept an `int` index into the parameter list, but don't return items
    from that collection directly (we transform them through the simulator
    first).
"""
def __init__(
self,
simulator: Callable[[P], R],
parameters: Sequence[P],
) -> None:
self.simulator = simulator
self.parameters = parameters
def __getitem__(self, uri: int) -> R:
return self.simulator(self.parameters[uri])
def __iter__(self) -> Iterator[int]:
return iter(range(len(self.parameters)))
def __len__(self) -> int:
return len(self.parameters)
class SimulatorPredictiveDomain[P, R](Domain[int, tuple[R, ...]]):
def __init__(
self,
simulator: Callable[[P], R],
parameters: Sequence[P],
predictives: Sequence[Callable[[R], R]],
) -> None:
self.simulator = simulator
self.parameters = parameters
self.predictives = predictives
def __getitem__(self, uri: int) -> tuple[R, ...]:
sample = self.simulator(self.parameters[uri])
return (
sample,
*(p(sample) for p in self.predictives)
)
def __iter__(self) -> Iterator[int]:
return iter(range(len(self.parameters)))
def __len__(self) -> int:
return len(self.parameters)
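Nothing in these simulator domains caches: every `__getitem__` call reruns the simulator. A standalone sketch of that lazy behavior (toy simulator and accessor, not part of the module):

```python
calls: list[float] = []

def simulator(p: float) -> float:
    calls.append(p)       # record each invocation
    return p * 2

parameters = [1.0, 2.0, 3.0]

def getitem(uri: int) -> float:
    # mirrors SimulatorDomain.__getitem__: simulate on access
    return simulator(parameters[uri])

assert calls == []        # setting things up runs nothing
assert getitem(1) == 4.0
assert getitem(1) == 4.0  # accessed twice -> simulated twice
assert calls == [2.0, 2.0]
```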

trainlib/estimator.py
"""
Development note
I'd rather lay out bare args and kwargs in the estimator methods, but the
allowed variance in subclasses makes this difficult. To ease the pain of typing
coordination with datasets, `loss()` / `forward()` / `metrics()` specify their
arguments through a TypedDict. This makes those signatures more portable, and
can bound type variables in other places.
In theory, even the single tensor input base currently in place could be
relaxed given the allowed variability in dataset / dataloader outputs. The
default collate function in the PyTorch `DataLoader` source leaves types like
strings and bytes unchanged, so not all kinds of data are batched into a
`Sequence`, let alone a tensor. Nevertheless, estimators are `nn.Module`
derivatives, so it's at minimum a safe assumption we'll need at least one
tensor (or *should* have one).
"""
import logging
from typing import Unpack, TypedDict
from collections.abc import Generator
from torch import nn, Tensor
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from trainlib.util.type import OptimizerKwargs
logger: logging.Logger = logging.getLogger(__name__)
class EstimatorKwargs(TypedDict):
inputs: Tensor
class Estimator[Kw: EstimatorKwargs](nn.Module):
"""
Estimator base class.
    All methods that raise ``NotImplementedError`` are directly invoked in the
``Trainer.train(...)`` loop.
Note the flexibility afforded to the signatures of `forward()`, `loss()`,
and `metrics()` methods, which should generally have identical sets of
    arguments in inheriting classes. The base class is generic to a type `Kw`,
which should be a `TypedDict` for inheriting classes (despite not being
enforceable as an upper bound) that reflects the keyword argument
structure for these methods.
For instance, in a sequence prediction model with labels and masking, you
might have something like:
.. code-block:: python
class PredictorKwargs(TypedDict, total=False):
labels: list[Tensor]
lengths: Tensor
mask: Tensor
class SequencePredictor(Estimator[PredictorKwargs]):
def forward(
self,
input: Tensor,
**kwargs: Unpack[PredictorKwargs],
) -> tuple[Tensor, ...]:
...
def loss(
self,
input: Tensor,
**kwargs: Unpack[PredictorKwargs],
) -> Generator[Tensor]:
...
def metrics(
self,
input: Tensor,
**kwargs: Unpack[PredictorKwargs],
) -> dict[str, float]:
...
While `loss` and `metrics` should leverage the full set of keyword
arguments, `forward` may not (e.g., in the example above, it shouldn't use
`labels`).
Subclasses of `SequencePredictor` should then be generic over subtypes of
`PredictorKwargs`.
"""
def forward(
self,
**kwargs: Unpack[Kw],
) -> tuple[Tensor, ...]:
raise NotImplementedError
def loss(
self,
**kwargs: Unpack[Kw],
) -> Generator:
"""
Compute model loss for the given input.
Note that the loss is implemented as a generator to support
multi-objective estimator setups. That is, losses can be yielded in
sequence, allowing for a training loop to propagate model parameter
updates before the next loss function is calculated. For instance, in a
GAN-like setup, one might first emit the D-loss, update D parameters,
then compute the G-loss (*depending* on the updated D parameters). Such
a scheme is not otherwise possible without bringing the intermediate
parameter update "in house" (breaking a separation of duties with the
train loop).
"""
raise NotImplementedError
def metrics(
self,
**kwargs: Unpack[Kw],
) -> dict[str, float]:
"""
Compute metrics for the given input.
"""
raise NotImplementedError
def optimizers(
self,
**kwargs: Unpack[OptimizerKwargs],
) -> tuple[Optimizer, ...]:
"""
Get optimizers for the estimator to use in training loops.
Example providing a singular Adam-based optimizer:
.. code-block:: python
def optimizers(...):
optimizer = torch.optim.AdamW(
self.parameters(),
lr=1e-3,
eps=1e-8,
)
return (optimizer,)
"""
raise NotImplementedError
def epoch_step(self) -> None:
"""
Step epoch-dependent model state.
This method should not include optimization of primary model
parameters; that should be left to external optimizers. Instead, this
method should step forward things like internal hyperparameter
schedules in accordance with the expected call rate (e.g., every
epoch).
"""
raise NotImplementedError
def epoch_write(
self,
writer: SummaryWriter,
step: int | None = None,
val: bool = False,
**kwargs: Unpack[Kw],
) -> None:
"""
Write epoch-dependent Tensorboard items.
If implemented, this should supplement that which is provided in
        ``metrics()``. Tensors provided via ``kwargs`` should include raw
training/validation data; examples include writing raw embeddings,
canonical visualizations of the samples (e.g., previewing images), etc.
Parameters:
            **kwargs: batch of tensors from the current epoch
writer: tensorboard writer instance
step: current step in the optimization loop
val: whether input is a validation sample
"""
raise NotImplementedError
def log_arch(self) -> None:
"""
Log Estimator architecture details.
"""
logger.info(f"> Estimator :: {self.__class__.__name__}")
num_params = sum(
p.numel() for p in self.parameters() if p.requires_grad
)
logger.info(f"| > # model parameters: {num_params}")

trainlib/estimators/rnn.py
import logging
from typing import Unpack, NotRequired
from collections.abc import Generator
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from trainlib.estimator import Estimator, EstimatorKwargs
from trainlib.util.type import OptimizerKwargs
from trainlib.util.module import get_grad_norm
from trainlib.estimators.tdnn import TDNNLayer
logger: logging.Logger = logging.getLogger(__name__)
class RNNKwargs(EstimatorKwargs):
inputs: Tensor
labels: NotRequired[Tensor]
class LSTM[K: RNNKwargs](Estimator[K]):
"""
Base RNN architecture.
"""
def __init__(
self,
input_dim: int,
output_dim: int,
hidden_dim: int = 64,
num_layers: int = 4,
bidirectional: bool = False,
verbose: bool = True,
) -> None:
"""
Parameters:
input_dim: dimensionality of the input
output_dim: dimensionality of the output
            hidden_dim: dimensionality of each LSTM layer output
            num_layers: number of stacked LSTM layers
"""
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.dense_in = nn.Linear(input_dim, hidden_dim)
self.lstm = nn.LSTM(
hidden_dim,
hidden_dim,
num_layers=num_layers,
batch_first=True,
bidirectional=bidirectional,
)
lstm_out_dim = hidden_dim * (2 if bidirectional else 1)
self.dense_z = nn.Linear(lstm_out_dim, output_dim)
# weight initialization for LSTM layers
def init_weights(m: nn.Module) -> None:
if isinstance(m, nn.LSTM):
for name, p in m.named_parameters():
if "weight_ih" in name:
nn.init.xavier_uniform_(p)
elif "weight_hh" in name:
nn.init.orthogonal_(p)
elif "bias" in name:
nn.init.zeros_(p)
self.apply(init_weights)
if verbose:
self.log_arch()
def _clamp_rand(self, x: Tensor) -> Tensor:
return torch.clamp(
x + (1.0 / 127.0) * (torch.rand_like(x) - 0.5),
min=-1.0,
max=1.0,
)
def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]:
inputs = kwargs["inputs"]
# data shaped (B, C, T); map to (B, T, C)
x = inputs.permute(0, 2, 1)
x = torch.tanh(self.dense_in(x))
x = self._clamp_rand(x)
x, hidden = self.lstm(x)
z = self.dense_z(x)
return z[:, -1, :], hidden
def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]:
predictions = self(**kwargs)[0]
labels = kwargs["labels"]
yield F.mse_loss(predictions, labels)
def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]:
with torch.no_grad():
loss = next(self.loss(**kwargs)).item()
return {
"loss": loss,
"grad_norm": get_grad_norm(self)
}
def optimizers(
self,
**kwargs: Unpack[OptimizerKwargs],
) -> tuple[Optimizer, ...]:
"""
"""
default_kwargs: Unpack[OptimizerKwargs] = {
"lr": 1e-3,
"eps": 1e-8,
}
opt_kwargs = {**default_kwargs, **kwargs}
optimizer = torch.optim.AdamW(
self.parameters(),
**opt_kwargs,
)
return (optimizer,)
def epoch_step(self) -> None:
return None
def epoch_write(
self,
writer: SummaryWriter,
step: int | None = None,
val: bool = False,
**kwargs: Unpack[K],
) -> None:
return None
def log_arch(self) -> None:
super().log_arch()
logger.info(f"| > {self.input_dim=}")
logger.info(f"| > {self.hidden_dim=}")
logger.info(f"| > {self.num_layers=}")
logger.info(f"| > {self.output_dim=}")
class MultiheadLSTMKwargs(EstimatorKwargs):
inputs: Tensor
labels: NotRequired[Tensor]
auxiliary: NotRequired[Tensor]
class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]):
def __init__(
self,
input_dim: int,
output_dim: int,
hidden_dim: int = 64,
num_layers: int = 4,
bidirectional: bool = False,
head_dims: list[int] | None = None,
verbose: bool = True,
) -> None:
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.head_dims = head_dims if head_dims is not None else []
self.dense_in = nn.Linear(input_dim, hidden_dim)
self.lstm = nn.LSTM(
hidden_dim,
hidden_dim,
num_layers=num_layers,
batch_first=True,
bidirectional=bidirectional,
)
lstm_out_dim = hidden_dim * (2 if bidirectional else 1)
self.dense_z_out = nn.Linear(lstm_out_dim, output_dim)
self.dense_z_heads = nn.ModuleList([
nn.Linear(lstm_out_dim, head_dim)
for head_dim in self.head_dims
])
# weight initialization for LSTM layers
def init_weights(m: nn.Module) -> None:
if isinstance(m, nn.LSTM):
for name, p in m.named_parameters():
if "weight_ih" in name:
nn.init.xavier_uniform_(p)
elif "weight_hh" in name:
nn.init.orthogonal_(p)
elif "bias" in name:
nn.init.zeros_(p)
self.apply(init_weights)
if verbose:
self.log_arch()
def _clamp_rand(self, x: Tensor) -> Tensor:
return torch.clamp(
x + (1.0 / 127.0) * (torch.rand_like(x) - 0.5),
min=-1.0,
max=1.0,
)
def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]:
inputs = kwargs["inputs"]
# data shaped (B, C, T); map to (B, T, C)
x = inputs.permute(0, 2, 1)
x = torch.tanh(self.dense_in(x))
x = self._clamp_rand(x)
x, hidden = self.lstm(x)
z = self.dense_z_out(x)
zs = torch.cat([head(x) for head in self.dense_z_heads], dim=-1)
return z[:, -1, :], zs[:, -1, :]
def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]:
pred, pred_aux = self(**kwargs)
labels = kwargs["labels"]
aux_labels = kwargs.get("auxiliary")
if aux_labels is None:
yield F.mse_loss(pred, labels)
else:
yield F.mse_loss(pred, labels) + F.mse_loss(pred_aux, aux_labels)
def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]:
with torch.no_grad():
loss = next(self.loss(**kwargs)).item()
return {
"loss": loss,
"grad_norm": get_grad_norm(self)
}
def optimizers(
self,
**kwargs: Unpack[OptimizerKwargs],
) -> tuple[Optimizer, ...]:
"""
"""
default_kwargs: Unpack[OptimizerKwargs] = {
"lr": 1e-3,
"eps": 1e-8,
}
opt_kwargs = {**default_kwargs, **kwargs}
optimizer = torch.optim.AdamW(
self.parameters(),
**opt_kwargs,
)
return (optimizer,)
def epoch_step(self) -> None:
return None
def epoch_write(
self,
writer: SummaryWriter,
step: int | None = None,
val: bool = False,
**kwargs: Unpack[K],
) -> None:
return None
def log_arch(self) -> None:
super().log_arch()
logger.info(f"| > {self.input_dim=}")
logger.info(f"| > {self.hidden_dim=}")
logger.info(f"| > {self.num_layers=}")
logger.info(f"| > {self.output_dim=}")
class ConvRNN[K: RNNKwargs](Estimator[K]):
"""
Base recurrent convolutional architecture.
Computes latents, initial states, and rate estimates from features and
    a lambda parameter.
"""
def __init__(
self,
input_dim: int,
output_dim: int,
temporal_dim: int,
gru_dim: int = 64,
conv_dim: int = 96,
num_layers: int = 4,
conv_kernel_sizes: list[int] | None = None,
conv_dilations: list[int] | None = None,
verbose: bool = True,
) -> None:
"""
Parameters:
input_dim: dimensionality of the input
output_dim: dimensionality of the output
gru_dim: dimensionality of each GRU layer output
conv_dim: dimensionality of each conv layer output
num_layers: number of gru-conv layer pairs to use
conv_kernel_sizes: kernel sizes for conv layers
conv_dilations: dilation settings for conv layers
"""
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.gru_dim = gru_dim
self.conv_dim = conv_dim
self.num_layers = num_layers
self.receptive_field = 0
self.conv_kernel_sizes: list[int]
if conv_kernel_sizes is None:
self.conv_kernel_sizes = [4] * num_layers
else:
self.conv_kernel_sizes = conv_kernel_sizes
self.conv_dilations: list[int]
if conv_dilations is None:
self.conv_dilations = [1] + [2] * (num_layers - 1)
else:
self.conv_dilations = conv_dilations
self._gru_layers: nn.ModuleList = nn.ModuleList()
self._conv_layers: nn.ModuleList = nn.ModuleList()
layer_in_dim = gru_dim
for i in range(self.num_layers):
gru_layer = nn.GRU(layer_in_dim, gru_dim, batch_first=True)
self._gru_layers.append(gru_layer)
layer_in_dim += gru_dim
tdnn_layer = TDNNLayer(
layer_in_dim,
conv_dim,
kernel_size=self.conv_kernel_sizes[i],
dilation=self.conv_dilations[i],
#pad=False,
)
self.receptive_field += tdnn_layer.receptive_field
self._conv_layers.append(tdnn_layer)
layer_in_dim += conv_dim
# self.dense_in = nn.Linear(self.input_dim, gru_dim)
self.dense_in = TDNNLayer(
self.input_dim,
gru_dim,
kernel_size=1,
pad=False
)
# will be (B, T, C), applies indep at each time step across channels
# self.dense_z = nn.Linear(layer_in_dim, self.output_dim)
# will be (B, C, T), applies indep at each time step across channels
self.dense_z = TDNNLayer(
layer_in_dim,
self.output_dim,
kernel_size=temporal_dim,
pad=False,
)
# weight initialization for GRU layers
def init_weights(module: nn.Module) -> None:
if isinstance(module, nn.GRU):
for p in module.named_parameters():
if p[0].startswith("weight_hh_"):
nn.init.orthogonal_(p[1])
self.apply(init_weights)
if verbose:
self.log_arch()
def _clamp_rand(self, x: Tensor) -> Tensor:
return torch.clamp(
x + (1.0 / 127.0) * (torch.rand_like(x) - 0.5),
min=-1.0,
max=1.0,
)
def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]:
inputs = kwargs["inputs"]
# embedding shaped (B, C, T)
x = self._clamp_rand(torch.tanh(self.dense_in(inputs)))
# prepare shape (B, T, C) -- for GRU
x = x.transpose(-2, -1)
for gru, conv in zip(self._gru_layers, self._conv_layers, strict=True):
xg = self._clamp_rand(gru(x)[0])
x = torch.cat([x, xg], -1)
xc = self._clamp_rand(conv(x.transpose(-2, -1)))
xc = xc.transpose(-2, -1)
x = torch.cat([x, xc], -1)
# z = self.dense_z(x)
# z = z.transpose(-2, -1)
x = x.transpose(-2, -1)
# map to (B, C, T)
z = self.dense_z(x)
return (z,)
def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]:
predictions = self(**kwargs)[0]
labels = kwargs["labels"]
# squeeze last dim; we've mapped T -> 1
predictions = predictions.squeeze(-1)
yield F.mse_loss(predictions, labels, reduction="mean")
def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]:
with torch.no_grad():
loss = next(self.loss(**kwargs)).item()
return {
"loss": loss,
"grad_norm": get_grad_norm(self)
}
def optimizers(
self,
**kwargs: Unpack[OptimizerKwargs],
) -> tuple[Optimizer, ...]:
"""
"""
default_kwargs: Unpack[OptimizerKwargs] = {
"lr": 1e-3,
"eps": 1e-8,
}
opt_kwargs = {**default_kwargs, **kwargs}
optimizer = torch.optim.AdamW(
self.parameters(),
**opt_kwargs,
)
return (optimizer,)
def epoch_step(self) -> None:
return None
def epoch_write(
self,
writer: SummaryWriter,
step: int | None = None,
val: bool = False,
**kwargs: Unpack[K],
) -> None:
return None
def log_arch(self) -> None:
super().log_arch()
logger.info(f"| > {self.input_dim=}")
logger.info(f"| > {self.gru_dim=}")
logger.info(f"| > {self.conv_dim=}")
logger.info(f"| > {self.num_layers=}")
logger.info(f"| > {self.conv_kernel_sizes=}")
logger.info(f"| > {self.conv_dilations=}")
logger.info(f"| > {self.receptive_field=}")
logger.info(f"| > {self.output_dim=}")
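Because the GRU/conv stack above is densely connected, the width of `dense_z`'s input grows with depth. A small standalone sketch (pure Python, no torch needed) of the `layer_in_dim` arithmetic from `__init__`:

```python
# Sketch of how the densely-connected feature width grows in __init__ above:
# each layer pair appends the GRU output (gru_dim) and then the conv output
# (conv_dim) to the running dimension, so the final dense_z input width is
#   gru_dim * (1 + num_layers) + conv_dim * num_layers.
def dense_z_input_dim(gru_dim: int, conv_dim: int, num_layers: int) -> int:
    layer_in_dim = gru_dim  # output width of dense_in
    for _ in range(num_layers):
        layer_in_dim += gru_dim   # concat of GRU output
        layer_in_dim += conv_dim  # concat of conv output
    return layer_in_dim

# with the class defaults (gru_dim=64, conv_dim=96, num_layers=4)
print(dense_z_input_dim(64, 96, 4))  # 704
```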

114
trainlib/estimators/tdnn.py Normal file
View File

@@ -0,0 +1,114 @@
import logging
from collections.abc import Generator
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.optim import Optimizer
from torch.nn.utils.parametrizations import weight_norm
logger: logging.Logger = logging.getLogger(__name__)
class TDNNLayer(nn.Module):
"""
Time delay neural network layer.
Built on torch Conv1D layers, with additional support for automatic
padding.
"""
def __init__(
self,
input_channels: int,
output_channels: int,
kernel_size: int = 3,
dilation: int = 1,
lookahead: int = 0,
pad: bool = True,
) -> None:
"""
Implements a fast TDNN layer via `torch.nn.Conv1d`.
Note that we're restricted to the kernel shapes producible by `Conv1d`
objects, implying (primarily) that the kernel must be symmetric and
have equal spacing.
For example, the symmetric but non-uniform context [-3, -2, 0, +2, +3]
cannot be represented, while [-6, -3, 0, 3, 6] can. A few other kernel
examples:
kernel_size=3; dilation=1 -> [-1, 0, 1]
kernel_size=3; dilation=3 -> [-3, 0, 3]
kernel_size=4; dilation=2 -> [-3, -1, 1, 3]
By default, the TDNN layer left pads the input to ensure the output has
the same sequence length as the original input. For example, with a
kernel size of 3 (dilation of 1) and sequence length of T=3, the
sequence will be left padded with 2 zeros:
[0, 0, 1, 1, 1] -> [x1, x2, x3]
If a lookahead is specified, some number of those left zeros will be
moved to the right. If lookahead=1, for instance, indicating 1
additional "future frame" of context, padding will look like
[0, 1, 1, 1, 0] -> [x1, x2, x3]
The output x_i now sees through time step i+1.
Parameters:
input_channels: number of input channels
output_channels: number of channels produced by the temporal
convolution
kernel_size: total size of the kernel
dilation: dilation of receptive field, i.e., the size of the gaps
lookahead: number of allowed lookahead frames
pad: whether the input should be padded, producing an output with
the same sequence length as the input
"""
super().__init__()
self.td_conv: nn.Module = weight_norm(
nn.Conv1d(
input_channels,
output_channels,
kernel_size=kernel_size,
dilation=dilation,
)
)
self.pad: bool = pad
self.lookahead = lookahead
self.receptive_field: int = (kernel_size - 1) * dilation + 1
assert (
self.lookahead < self.receptive_field
), "Lookahead cannot exceed receptive field"
def forward(self, x: Tensor) -> Tensor:
"""
Dimension definitions:
- B: batch size
- Di: input dimension at single time step (aka input channels)
- Do: output dimension at single time step (aka output channels)
- T: sequence length
Parameters:
x: input tensor, shaped [B, Di, T] (optionally w/o a batch dim)
Returns:
tensor shaped [B, Do, T] if `pad=True`, otherwise
[B, Do, T - (kernel_size - 1) * dilation]
"""
# pad according to receptive field and lookahead s.t. output
# shape [B, *, T]
if self.pad:
x = F.pad(
x,
(self.receptive_field - self.lookahead - 1, self.lookahead)
)
return self.td_conv(x)
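The context and padding examples in the `TDNNLayer` docstring reduce to a few lines of arithmetic. A standalone sketch (no torch required) reproducing them:

```python
# Standalone sketch of the TDNNLayer context/padding arithmetic described in
# the docstring above.
def receptive_field(kernel_size: int, dilation: int) -> int:
    return (kernel_size - 1) * dilation + 1

def context_offsets(kernel_size: int, dilation: int) -> list[float]:
    # tap positions relative to the kernel center
    center = (kernel_size - 1) / 2
    return [(i - center) * dilation for i in range(kernel_size)]

def pad_amounts(kernel_size: int, dilation: int, lookahead: int = 0) -> tuple[int, int]:
    # (left, right) zero padding applied in forward() when pad=True
    rf = receptive_field(kernel_size, dilation)
    return rf - lookahead - 1, lookahead

print(context_offsets(3, 3))           # [-3.0, 0.0, 3.0]
print(context_offsets(4, 2))           # [-3.0, -1.0, 1.0, 3.0]
print(pad_amounts(3, 1))               # (2, 0): the [0, 0, 1, 1, 1] case
print(pad_amounts(3, 1, lookahead=1))  # (1, 1): the [0, 1, 1, 1, 0] case
```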

509
trainlib/trainer.py Normal file
View File

@@ -0,0 +1,509 @@
import os
import time
import logging
from io import BytesIO
from copy import deepcopy
from typing import Any, Self
from pathlib import Path
from collections import defaultdict
from collections.abc import Callable
import torch
from tqdm import tqdm
from torch import cuda, Tensor
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from trainlib.dataset import BatchedDataset
from trainlib.estimator import Estimator, EstimatorKwargs
from trainlib.transform import Transform
from trainlib.utils.type import (
SplitKwargs,
LoaderKwargs,
BalanceKwargs,
)
from trainlib.utils.module import ModelWrapper
logger: logging.Logger = logging.getLogger(__name__)
class Trainer[I, K: EstimatorKwargs]:
"""
Training interface for updating ``Estimators`` with ``Datasets``.
"""
def __init__(
self,
estimator: Estimator[K],
device: str | None = None,
chkpt_dir: str = "chkpt/",
tblog_dir: str = "tblog/",
) -> None:
"""
Parameters:
estimator: `Estimator` model object
device: device on which to carry out training
"""
self.device: str
if device is None:
self.device = "cuda" if cuda.is_available() else "cpu"
else:
self.device = device
logger.info(f"> Trainer device: {self.device}")
if self.device.startswith("cuda"):
if torch.cuda.is_available():
# extra cuda details
logger.info(f"| > {cuda.device_count()=}")
logger.info(f"| > {cuda.current_device()=}")
logger.info(f"| > {cuda.get_device_name()=}")
logger.info(f"| > {cuda.get_device_capability()=}")
# memory info (in GB)
gb = 1024**3
memory_allocated = cuda.memory_allocated() / gb
memory_reserved = cuda.memory_reserved() / gb
memory_total = cuda.get_device_properties(0).total_memory / gb
logger.info("| > CUDA memory:")
logger.info(f"| > {memory_total=:.2f}GB")
logger.info(f"| > {memory_reserved=:.2f}GB")
logger.info(f"| > {memory_allocated=:.2f}GB")
else:
logger.warning("| > CUDA device specified but not available")
else:
logger.info("| > Using CPU device - no additional device info")
self.estimator = estimator
self.estimator.to(self.device)
self.chkpt_dir = Path(chkpt_dir).resolve()
self.tblog_dir = Path(tblog_dir).resolve()
self.reset()
def reset(self) -> None:
"""
Set base tracking parameters.
"""
self._step: int = 0
self._epoch: int = 0
self._summary: dict[str, list[tuple[float, int]]] = defaultdict(list)
self._val_loss = float("inf")
self._best_val_loss = float("inf")
self._stagnant_epochs = 0
self._best_model_state_dict: dict[str, Any] = {}
def train(
self,
dataset: BatchedDataset[..., ..., I],
batch_estimator_map: Callable[[I, Self], K],
lr: float = 1e-3,
eps: float = 1e-8,
max_grad_norm: float | None = None,
max_epochs: int = 10,
stop_after_epochs: int = 5,
batch_size: int = 256,
val_frac: float = 0.1,
train_transform: Transform | None = None,
val_transform: Transform | None = None,
dataset_split_kwargs: SplitKwargs | None = None,
dataset_balance_kwargs: BalanceKwargs | None = None,
dataloader_kwargs: LoaderKwargs | None = None,
summarize_every: int = 1,
chkpt_every: int = 1,
resume_latest: bool = False,
summary_writer: SummaryWriter | None = None,
) -> Estimator:
"""
Note: this method implements a general scheme for passing items from
the dataloader to the estimator's loss function. The abstract
`Estimator` base only requires the model output for a loss
calculation, but concrete estimators often require additional
arguments (e.g., labels or length masks, as with sequential models).
This method defers any further logic to the `loss` method of the
underlying estimator, so take care to synchronize the sample structure
of `dataset` with that expected by `self.estimator.loss(...)`.
On batch_estimator_map:
Dataloader collate functions are responsible for mapping a collection
of items into an item of collections, roughly speaking. If items are
tuples of tensors,
[
( [1, 1], [1, 1] ),
( [2, 2], [2, 2] ),
( [3, 3], [3, 3] ),
]
the collate function maps back into the item skeleton, producing a
single tuple of (stacked) tensors
( [[1, 1],
[2, 2],
[3, 3]],
[[1, 1],
[2, 2],
[3, 3]] )
This function should map from batches (which should be *item shaped*,
i.e., have an `I` skeleton, even if stacked items may be different on
the inside) into estimator keyword arguments (type `K`).
Parameters:
lr: learning rate (default: 1e-3)
eps: adam EPS (default: 1e-8)
max_epochs: maximum number of training epochs
stop_after_epochs: number of epochs with stagnant validation losses
to allow before early stopping. If training stops earlier, the
parameters for the best recorded validation score are loaded
into the estimator before the method returns. If
`stop_after_epochs >= max_epochs`, the estimator will train
over all epochs and return as is, irrespective of validation
scores.
batch_size: size of batch to use when training on the provided
dataset
val_frac: fraction of dataset to use for validation
summarize_every: how often (in epochs) summaries are flushed to
the summary writer
chkpt_every: how often (in epochs) model checkpoints should be saved
resume_latest: resume training from the latest available checkpoint
in the `chkpt_dir`
"""
logger.info("> Begin train loop:")
logger.info(f"| > {lr=}")
logger.info(f"| > {eps=}")
logger.info(f"| > {max_epochs=}")
logger.info(f"| > {batch_size=}")
logger.info(f"| > {val_frac=}")
logger.info(f"| > {chkpt_every=}")
logger.info(f"| > {resume_latest=}")
logger.info(f"| > with device: {self.device}")
logger.info(f"| > core count: {os.cpu_count()}")
writer: SummaryWriter
dir_prefix = str(int(time.time()))
if summary_writer is None:
writer = SummaryWriter(f"{self.tblog_dir}")
else:
writer = summary_writer
train_loader, val_loader = self.get_dataloaders(
dataset,
batch_size,
val_frac=val_frac,
train_transform=train_transform,
val_transform=val_transform,
dataset_split_kwargs=dataset_split_kwargs,
dataset_balance_kwargs=dataset_balance_kwargs,
dataloader_kwargs=dataloader_kwargs,
)
optimizers = self.estimator.optimizers(lr=lr, eps=eps)
self._step = 0
self._epoch = 1 # start from 1 for logging convenience
while self._epoch <= max_epochs and not self._converged(
self._epoch, stop_after_epochs
):
print(f"Training epoch {self._epoch}/{max_epochs}...")
print(f"Stagnant epochs {self._stagnant_epochs}/{stop_after_epochs}...")
epoch_start_time = time.time()
train_loss_sums = []
self.estimator.train()
with tqdm(train_loader, unit="batch") as train_epoch:
for i, batch_data in enumerate(train_epoch):
est_kwargs = batch_estimator_map(batch_data, self)
inputs = est_kwargs["inputs"]
# one-time logging
if self._step == 0:
writer.add_graph(ModelWrapper(self.estimator), est_kwargs)
# once-per-epoch logging
if i == 0:
self.estimator.epoch_write(
writer,
step=self._step,
val=False,
**est_kwargs
)
train_losses = self.estimator.loss(**est_kwargs)
train_loss_items = []
for o_idx, optimizer in enumerate(optimizers):
optimizer.zero_grad()
train_loss = next(train_losses)
if len(train_loss_sums) <= o_idx:
train_loss_sums.append(0.0)
train_loss_item = train_loss.item()
train_loss_sums[o_idx] += train_loss_item
train_loss_items.append(train_loss_item)
train_loss.backward()
# clip gradients for optimizer's parameters
if max_grad_norm is not None:
opt_params = self._get_optimizer_parameters(optimizer)
clip_grad_norm_(opt_params, max_norm=max_grad_norm)
optimizer.step()
self._step += len(inputs)
for train_loss_item, train_loss_sum in zip(
train_loss_items,
train_loss_sums,
strict=True,
):
train_epoch.set_postfix(loss=f"{train_loss_sum/(i+1):8.2f}")
self._add_summary_item("train_loss", train_loss_item)
estimator_metrics = self.estimator.metrics(**est_kwargs)
for metric_name, metric_value in estimator_metrics.items():
self._add_summary_item(f"train_{metric_name}", metric_value)
self.estimator.epoch_step()
for li, train_loss_sum in enumerate(train_loss_sums):
self._add_summary_item(
f"train_loss{li}_epoch", train_loss_sum / len(train_loader)
)
if val_frac > 0:
val_loss_sums = []
self.estimator.eval()
with tqdm(val_loader, unit="batch") as val_epoch:
for i, batch_data in enumerate(val_epoch):
est_kwargs = batch_estimator_map(batch_data, self)
inputs = est_kwargs["inputs"]
# once-per-epoch logging
if i == 0:
self.estimator.epoch_write(
writer,
step=self._step,
val=True,
**est_kwargs
)
val_losses = self.estimator.loss(**est_kwargs)
val_loss_items = []
for o_idx in range(len(optimizers)):
val_loss = next(val_losses)
if len(val_loss_sums) <= o_idx:
val_loss_sums.append(0.0)
val_loss_item = val_loss.item()
val_loss_sums[o_idx] += val_loss_item
val_loss_items.append(val_loss_item)
for val_loss_item, val_loss_sum in zip(
val_loss_items,
val_loss_sums,
strict=True,
):
val_epoch.set_postfix(loss=f"{val_loss_sum/(i+1):8.2f}")
self._add_summary_item("val_loss", val_loss_item)
estimator_metrics = self.estimator.metrics(**est_kwargs)
for metric_name, metric_value in estimator_metrics.items():
self._add_summary_item(f"val_{metric_name}", metric_value)
for li, val_loss_sum in enumerate(val_loss_sums):
self._add_summary_item(
f"val_loss{li}_epoch", val_loss_sum / len(val_loader)
)
# convergence of multiple losses may be ambiguous
self._val_loss = sum(val_loss_sums) / len(val_loader)
self._add_summary_item("epoch_time_sec", time.time() - epoch_start_time)
if self._epoch % summarize_every == 0:
self._summarize(writer, self._epoch)
# save checkpoint
if self._epoch % chkpt_every == 0:
self.save_model(
self._epoch, self.chkpt_dir, dir_prefix
)
self._epoch += 1
return self.estimator
def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
converged = False
if epoch == 1 or self._val_loss < self._best_val_loss:
self._best_val_loss = self._val_loss
self._stagnant_epochs = 0
self._best_model_state_dict = deepcopy(self.estimator.state_dict())
else:
self._stagnant_epochs += 1
if self._stagnant_epochs >= stop_after_epochs:
self.estimator.load_state_dict(self._best_model_state_dict)
converged = True
return converged
@staticmethod
def get_dataloaders(
dataset: BatchedDataset,
batch_size: int,
val_frac: float = 0.1,
train_transform: Transform | None = None,
val_transform: Transform | None = None,
dataset_split_kwargs: SplitKwargs | None = None,
dataset_balance_kwargs: BalanceKwargs | None = None,
dataloader_kwargs: LoaderKwargs | None = None,
) -> tuple[DataLoader, DataLoader]:
"""
Create training and validation dataloaders for the provided dataset.
"""
if dataset_split_kwargs is None:
dataset_split_kwargs = {}
if dataset_balance_kwargs is not None:
dataset.balance(**dataset_balance_kwargs)
if val_frac <= 0:
dataset.post_transform = train_transform
train_loader_kwargs: LoaderKwargs = {
"batch_size": min(batch_size, len(dataset)),
"num_workers": 0,
"shuffle": True,
}
if dataloader_kwargs is not None:
train_loader_kwargs = {
**train_loader_kwargs,
**dataloader_kwargs
}
# no validation requested; pair the train loader with an empty
# placeholder that the train loop never iterates
return (
DataLoader(dataset, **train_loader_kwargs),
DataLoader(Dataset())
)
train_dataset, val_dataset = dataset.split(
[1 - val_frac, val_frac],
**dataset_split_kwargs,
)
# Dataset.split() returns light Subset objects of shallow copies of the
# underlying dataset; can change the transform attribute of both splits
# w/o overwriting
train_dataset.post_transform = train_transform
val_dataset.post_transform = val_transform
train_loader_kwargs: LoaderKwargs = {
"batch_size": min(batch_size, len(train_dataset)),
"num_workers": 0,
"shuffle": True,
}
val_loader_kwargs: LoaderKwargs = {
"batch_size": min(batch_size, len(val_dataset)),
"num_workers": 0,
"shuffle": True, # shuffle to prevent homogeneous val batches
}
if dataloader_kwargs is not None:
train_loader_kwargs = {**train_loader_kwargs, **dataloader_kwargs}
val_loader_kwargs = {**val_loader_kwargs, **dataloader_kwargs}
train_loader = DataLoader(train_dataset, **train_loader_kwargs)
val_loader = DataLoader(val_dataset, **val_loader_kwargs)
return train_loader, val_loader
def _summarize(self, writer: SummaryWriter, epoch: int) -> None:
"""
Flush the training summary to the TB summary writer.
"""
summary_values = defaultdict(list)
for name, records in self._summary.items():
for value, step in records:
writer.add_scalar(name, value, step)
summary_values[name].append(value)
print(f"==== Epoch [{epoch}] summary ====")
for name, values in summary_values.items():
mean_value = torch.tensor(values).mean().item()
print(f"> ({len(values)}) {name} :: {mean_value:.2f}")
writer.flush()
self._summary = defaultdict(list)
def _get_optimizer_parameters(
self,
optimizer: torch.optim.Optimizer,
) -> list[Tensor]:
return [
param
for param_group in optimizer.param_groups
for param in param_group["params"]
if param.grad is not None
]
def _add_summary_item(self, name: str, value: float) -> None:
self._summary[name].append((value, self._step))
def save_model(
self,
epoch: int,
chkpt_dir: str | Path,
dir_prefix: str,
) -> None:
"""
Save a model checkpoint.
"""
model_buff = BytesIO()
torch.save(self.estimator.state_dict(), model_buff)
model_buff.seek(0)
model_class = self.estimator.__class__.__name__
chkpt_name = f"m_{model_class}-e_{epoch}.pth"
chkpt_dir = Path(chkpt_dir, dir_prefix)
chkpt_path = Path(chkpt_dir, chkpt_name)
chkpt_dir.mkdir(parents=True, exist_ok=True)
chkpt_path.write_bytes(model_buff.getvalue())
def load_model(
self,
epoch: int,
chkpt_dir: str,
) -> None:
"""
Load a model checkpoint from a given epoch.
Note that this assumes the model was saved via `Trainer.save_model()`,
and the estimator provided to this `Trainer` instance matches the
architecture of the checkpoint model being loaded.
"""
model_class = self.estimator.__class__.__name__
chkpt_name = f"m_{model_class}-e_{epoch}.pth"
chkpt_path = Path(chkpt_dir, chkpt_name)
model_buff = BytesIO(chkpt_path.read_bytes())
model_buff.seek(0)
model_dict = torch.load(model_buff, weights_only=True)
self.estimator.load_state_dict(model_dict)
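The early-stopping bookkeeping in `Trainer._converged()` can be isolated into a few lines: training halts once validation loss has failed to improve for `stop_after_epochs` consecutive epochs. A minimal standalone sketch:

```python
# Standalone sketch of the early-stopping logic in Trainer._converged() above.
def run_early_stopping(val_losses: list[float], stop_after_epochs: int) -> int:
    """Return the number of epochs actually trained."""
    best = float("inf")
    stagnant = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if epoch == 1 or loss < best:
            # new best: reset the stagnation counter
            best = loss
            stagnant = 0
        else:
            stagnant += 1
        if stagnant >= stop_after_epochs:
            return epoch
    return len(val_losses)

# improves for 3 epochs, then stalls; stops after 2 stagnant epochs
print(run_early_stopping([3.0, 2.0, 1.0, 1.5, 1.4], stop_after_epochs=2))  # 5
```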

11
trainlib/transform.py Normal file
View File

@@ -0,0 +1,11 @@
class Transform[I]:
"""
Dataset transform base class.
In places that directly reference a base ``Transform[I]``, a hint
``Callable[[I], I]`` would suffice. This class exists to allow nominal
checks for purpose-built transforms.
"""
def __call__(self, item: I) -> I:
raise NotImplementedError
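A concrete transform might look like the following hypothetical `Scale` (a sketch, not part of the library; a minimal non-generic stand-in mirrors the base so the example is self-contained). Nominal subclassing lets call sites require a `Transform` rather than any bare callable:

```python
# Minimal stand-in for the Transform base above (generics omitted for brevity)
class Transform:
    def __call__(self, item):
        raise NotImplementedError

# hypothetical concrete transform scaling each element of a numeric item
class Scale(Transform):
    def __init__(self, factor: float) -> None:
        self.factor = factor

    def __call__(self, item: list[float]) -> list[float]:
        return [self.factor * x for x in item]

scale = Scale(2.0)
print(scale([1.0, 2.0, 3.0]))       # [2.0, 4.0, 6.0]
print(isinstance(scale, Transform))  # True, so nominal checks succeed
```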

54
trainlib/utils/job.py Normal file
View File

@@ -0,0 +1,54 @@
import logging
import concurrent
from concurrent.futures import Future, as_completed
from tqdm import tqdm
from colorama import Fore, Style
from trainlib.utils.text import color_text
logger: logging.Logger = logging.getLogger(__name__)
def process_futures(
futures: list[Future],
desc: str | None = None,
unit: str | None = None,
) -> None:
if desc is None:
desc = "Awaiting futures"
if unit is None:
unit = "it"
success = 0
cancelled = 0
errored = 0
submitted = len(futures)
progress_bar = tqdm(
total=len(futures),
desc=f"{desc} [submitted {len(futures)}]",
unit=unit,
)
for future in as_completed(futures):
try:
future.result()
success += 1
except concurrent.futures.CancelledError as e:
cancelled += 1
logger.error(f'Future cancelled; "{e}"')
except Exception as e:
errored += 1
logger.error(f'Future failed with unexpected exception "{e}"')
suc_txt = color_text(f"{success}", Fore.GREEN)
can_txt = color_text(f"{cancelled}", Fore.YELLOW)
err_txt = color_text(f"{errored}", Fore.RED)
tot_txt = color_text(f"{success+cancelled+errored}", Style.BRIGHT)
progress_bar.set_description(
f"{desc} [{tot_txt} / {submitted} | {suc_txt} {can_txt} {err_txt}]"
)
progress_bar.update(n=1)
progress_bar.close()
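Stripped of the tqdm/colorama presentation layer, the core of `process_futures()` is a success/cancelled/errored tally over `as_completed`. A self-contained sketch:

```python
# Self-contained sketch of the bookkeeping in process_futures() above,
# without the progress-bar and color formatting.
from concurrent.futures import CancelledError, Future, ThreadPoolExecutor, as_completed

def tally_futures(futures: list[Future]) -> tuple[int, int, int]:
    success = cancelled = errored = 0
    for future in as_completed(futures):
        try:
            future.result()
            success += 1
        except CancelledError:
            cancelled += 1
        except Exception:
            errored += 1
    return success, cancelled, errored

# hypothetical workload: one of four tasks raises
def work(n: int) -> int:
    if n == 2:
        raise ValueError("boom")
    return n

with ThreadPoolExecutor(max_workers=2) as pool:
    futures = [pool.submit(work, n) for n in range(4)]

counts = tally_futures(futures)
print(counts)  # (3, 0, 1)
```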

25
trainlib/utils/module.py Normal file
View File

@@ -0,0 +1,25 @@
import torch
from torch import nn
class ModelWrapper(nn.Module):
"""
Wraps a module whose forward takes keyword arguments so it can be
traced by utilities (e.g., ``SummaryWriter.add_graph``) that pass a
single positional input.
"""
def __init__(self, model: nn.Module) -> None:
super().__init__()
self.model = model
def forward(self, kwargs):
return self.model(**kwargs)
def get_grad_norm(model: nn.Module, p: int = 2) -> float:
"""Return the p-norm of all gradients currently held by `model`."""
norm = 0.0
for param in model.parameters():
if not param.requires_grad or param.grad is None:
continue
grad_item = torch.abs(param.grad).pow(p).sum().item()
norm += float(grad_item)
return norm ** (1 / p)
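The quantity computed by `get_grad_norm()` is simply the flat p-norm of all gradient entries, `(sum_i |g_i|**p) ** (1/p)`. The arithmetic in plain Python, with a 3-4-5 sanity check:

```python
# Pure-Python sketch of the gradient-norm arithmetic in get_grad_norm() above:
# treat all gradients as one flat vector and take its p-norm.
def flat_p_norm(grads: list[list[float]], p: int = 2) -> float:
    total = 0.0
    for grad in grads:
        total += sum(abs(g) ** p for g in grad)
    return total ** (1 / p)

print(flat_p_norm([[3.0], [4.0]]))  # 5.0: the 3-4-5 triangle
```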

8
trainlib/utils/text.py Normal file
View File

@@ -0,0 +1,8 @@
from typing import Any
from colorama import Style
def color_text(text: str, *colorama_args: Any) -> str:
return f"{''.join(colorama_args)}{text}{Style.RESET_ALL}"

52
trainlib/utils/type.py Normal file
View File

@@ -0,0 +1,52 @@
from typing import Any, TypedDict
from collections.abc import Callable, Iterable
from torch import Tensor
from torch.utils.data.sampler import Sampler
from trainlib.dataset import BatchedDataset
class LoaderKwargs(TypedDict, total=False):
batch_size: int
shuffle: bool
sampler: Sampler | Iterable | None
batch_sampler: Sampler[list] | Iterable[list] | None
num_workers: int
collate_fn: Callable[[list], Any]
pin_memory: bool
drop_last: bool
timeout: float
worker_init_fn: Callable[[int], None]
multiprocessing_context: object
generator: object
prefetch_factor: int
persistent_workers: bool
pin_memory_device: str
in_order: bool
class SplitKwargs(TypedDict, total=False):
dataset: BatchedDataset | None
by_attr: str | list[str | None] | None
shuffle_strata: bool
class BalanceKwargs(TypedDict, total=False):
by_attr: str | list[str | None] | None
split_min_sizes: list[int] | None
split_max_sizes: list[int] | None
shuffle_strata: bool
class OptimizerKwargs(TypedDict, total=False):
lr: float | Tensor
betas: tuple[float | Tensor, float | Tensor]
eps: float
weight_decay: float
amsgrad: bool
maximize: bool
foreach: bool | None
capturable: bool
differentiable: bool
fused: bool | None
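These `total=False` TypedDicts support the default-override pattern used in the estimator's `optimizers()` method: defaults are merged under caller kwargs with a dict union, so later entries win. A standalone sketch with a trimmed-down `OptimizerKwargs`:

```python
# Sketch of the kwargs-merging pattern from optimizers(), using a trimmed
# version of the OptimizerKwargs TypedDict above.
from typing import TypedDict

class OptimizerKwargs(TypedDict, total=False):
    lr: float
    eps: float

default_kwargs: OptimizerKwargs = {"lr": 1e-3, "eps": 1e-8}
user_kwargs: OptimizerKwargs = {"lr": 5e-4}

# later entries win, so user-provided values override the defaults
opt_kwargs = {**default_kwargs, **user_kwargs}
print(opt_kwargs)  # {'lr': 0.0005, 'eps': 1e-08}
```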

1739
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff