initial commit
This commit is contained in:
20
.gitignore
vendored
Normal file
20
.gitignore
vendored
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
# generic
|
||||||
|
__pycache__/
|
||||||
|
*.egg-info/
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# package-specific
|
||||||
|
.ipynb_checkpoints/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# vendor/build files
|
||||||
|
dist/
|
||||||
|
build/
|
||||||
|
doc/_autoref/
|
||||||
|
doc/_autosummary/
|
||||||
|
doc/_build/
|
||||||
|
|
||||||
|
# misc local
|
||||||
|
/Makefile
|
||||||
|
notebooks/
|
||||||
|
|
||||||
105
README.md
Normal file
105
README.md
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
# Overview
|
||||||
|
Package summary goes here, ideally with a diagram
|
||||||
|
|
||||||
|
# Install
|
||||||
|
Installation instructions
|
||||||
|
|
||||||
|
```sh
|
||||||
|
pip install <package>
|
||||||
|
```
|
||||||
|
|
||||||
|
or as a CLI tool
|
||||||
|
|
||||||
|
```sh
|
||||||
|
uv tool install <package>
|
||||||
|
```
|
||||||
|
|
||||||
|
# Development
|
||||||
|
- Initialize/synchronize the project with `uv sync`, creating a virtual
|
||||||
|
environment with base package dependencies.
|
||||||
|
- Depending on needs, install the development dependencies with `uv sync
|
||||||
|
--extra dev`.
|
||||||
|
|
||||||
|
# Testing
|
||||||
|
- To run the unit tests, make sure to first have the test dependencies
|
||||||
|
installed with `uv sync --extra test`, then run `make test`.
|
||||||
|
- For notebook testing, run `make install-kernel` to make the environment
|
||||||
|
available as a Jupyter kernel (to be selected when running notebooks).
|
||||||
|
|
||||||
|
# Documentation
|
||||||
|
- Install the documentation dependencies with `uv sync --extra doc`.
|
||||||
|
- Run `make docs-build` (optionally preceded by `make docs-clean`), and serve
|
||||||
|
locally with `docs-serve`.
|
||||||
|
|
||||||
|
# Development remarks
|
||||||
|
- Across `Trainer` / `Estimator` / `Dataset`, I've considered a
|
||||||
|
`ParamSpec`-based typing scheme to better orchestrate alignment in the
|
||||||
|
`Trainer.train()` loop, e.g., so we can statically check whether a dataset
|
||||||
|
appears to be fulfilling the argument requirements for the estimator's
|
||||||
|
`loss()` / `metrics()` methods. Something like
|
||||||
|
|
||||||
|
```py
|
||||||
|
class Estimator[**P](nn.Module):
|
||||||
|
def loss(
|
||||||
|
self,
|
||||||
|
input: Tensor,
|
||||||
|
*args: P.args,
|
||||||
|
**kwargs: P.kwargs,
|
||||||
|
) -> Generator:
|
||||||
|
...
|
||||||
|
|
||||||
|
class Trainer[**P]:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
estimator: Estimator[P],
|
||||||
|
...
|
||||||
|
): ...
|
||||||
|
```
|
||||||
|
|
||||||
|
might be how we begin threading signatures. But ensuring dataset items can
|
||||||
|
match `P` is challenging. You can consider a "packed" object where we
|
||||||
|
obfuscate passing data through `P`-signatures:
|
||||||
|
|
||||||
|
```py
|
||||||
|
class PackedItem[**P]:
|
||||||
|
def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
|
||||||
|
self._args = args
|
||||||
|
self._kwargs = kwargs
|
||||||
|
|
||||||
|
def apply[R](self, func: Callable[P, R]) -> R:
|
||||||
|
return func(*self._args, **self._kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchedDataset[U, R, I, **P](Dataset):
|
||||||
|
@abstractmethod
|
||||||
|
def _process_item_data(
|
||||||
|
self,
|
||||||
|
item_data: I,
|
||||||
|
item_index: int,
|
||||||
|
) -> PackedItem[P]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[PackedItem[P]]:
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
Meaningfully shaping those signatures is what remains, but you can't really
|
||||||
|
do this, not with typical type expression flexibility. For instance, if I'm
|
||||||
|
trying to appropriately type my base `TupleDataset`:
|
||||||
|
|
||||||
|
```py
|
||||||
|
class SequenceDataset[I, **P](HomogenousDataset[int, I, I, P]):
|
||||||
|
...
|
||||||
|
|
||||||
|
class TupleDataset[I](SequenceDataset[tuple[I, ...], ??]):
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
Here there's no way for me to shape a `ParamSpec` to indicate arbitrarily
|
||||||
|
many arguments of a fixed type (`I` in this case) to allow me to unpack my
|
||||||
|
item tuples into an appropriate `PackedItem`.
|
||||||
|
|
||||||
|
Until this (among other issues) becomes clearer, I'm setting up around a
|
||||||
|
simpler `TypedDict` type variable. We won't have particularly strong static
|
||||||
|
checks for item alignment inside `Trainer`, but this seems about as good as I
|
||||||
|
can get around the current infrastructure.
|
||||||
84
pyproject.toml
Normal file
84
pyproject.toml
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["setuptools", "wheel"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "trainlib"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training."
|
||||||
|
requires-python = ">=3.13"
|
||||||
|
authors = [
|
||||||
|
{ name="Sam Griesemer", email="git@olog.io" },
|
||||||
|
]
|
||||||
|
readme = "README.md"
|
||||||
|
license = "MIT"
|
||||||
|
keywords = [
|
||||||
|
"machine-learning",
|
||||||
|
]
|
||||||
|
classifiers = [
|
||||||
|
"Programming Language :: Python",
|
||||||
|
"Operating System :: OS Independent",
|
||||||
|
"Development Status :: 3 - Alpha",
|
||||||
|
|
||||||
|
"Intended Audience :: Developers",
|
||||||
|
"Intended Audience :: End Users/Desktop",
|
||||||
|
]
|
||||||
|
dependencies = [
|
||||||
|
"colorama>=0.4.6",
|
||||||
|
"matplotlib>=3.10.8",
|
||||||
|
"numpy>=2.4.1",
|
||||||
|
"tensorboard>=2.20.0",
|
||||||
|
"torch>=2.5.1",
|
||||||
|
"tqdm>=4.67.1",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
symconf = "trainlib.__main__:main"
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"ipykernel",
|
||||||
|
]
|
||||||
|
doc = [
|
||||||
|
"furo",
|
||||||
|
"myst-parser",
|
||||||
|
"sphinx",
|
||||||
|
"sphinx-togglebutton",
|
||||||
|
"sphinx-autodoc-typehints",
|
||||||
|
]
|
||||||
|
test = [
|
||||||
|
"pytest",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.urls]
|
||||||
|
Homepage = "https://doc.olog.io/trainlib"
|
||||||
|
Documentation = "https://doc.olog.io/trainlib"
|
||||||
|
Repository = "https://git.olog.io/olog/trainlib"
|
||||||
|
Issues = "https://git.olog.io/olog/trainlib/issues"
|
||||||
|
|
||||||
|
[tool.setuptools.packages.find]
|
||||||
|
include = ["trainlib*"]
|
||||||
|
|
||||||
|
# for static data files under package root
|
||||||
|
# [tool.setuptools.package-data]
|
||||||
|
# "<package>" = ["data/*.toml"]
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
line-length = 79
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
select = ["ANN", "E", "F", "UP", "B", "SIM", "I", "C4", "PERF"]
|
||||||
|
|
||||||
|
[tool.ruff.lint.isort]
|
||||||
|
length-sort = true
|
||||||
|
order-by-type = false
|
||||||
|
force-sort-within-sections = false
|
||||||
|
|
||||||
|
[tool.ruff.lint.per-file-ignores]
|
||||||
|
"tests/**" = ["S101"]
|
||||||
|
"**/__init__.py" = ["F401"]
|
||||||
|
|
||||||
|
[tool.ruff.format]
|
||||||
|
quote-style = "double"
|
||||||
|
indent-style = "space"
|
||||||
|
docstring-code-format = true
|
||||||
0
trainlib/__init__.py
Normal file
0
trainlib/__init__.py
Normal file
964
trainlib/dataset.py
Normal file
964
trainlib/dataset.py
Normal file
@@ -0,0 +1,964 @@
|
|||||||
|
"""
|
||||||
|
Marginalizing out the modality layer:
|
||||||
|
|
||||||
|
With ``domain`` being an instance variable, one possible interpretation of
|
||||||
|
the object structures here is that one could completely abstract away
|
||||||
|
the domain model, defining only item structures and processing data. You
|
||||||
|
could have a single dataset definition for a particular concrete dataset,
|
||||||
|
and so long as we're talking about the same items, it can be instantiated
|
||||||
|
using *any domain*. You wouldn't need specific subclasses for disk or
|
||||||
|
network or in-memory; you can tell it directly at runtime.
|
||||||
|
|
||||||
|
That's an eventual possibility, anyway. As it stands, however, this is
|
||||||
|
effectively impossible:
|
||||||
|
|
||||||
|
You can't easily abstract the batch -> item splitting process, i.e.,
|
||||||
|
``_process_batch_data()``. A list-based version of the dataset you're
|
||||||
|
trying to define might have an individual item tuple at every index,
|
||||||
|
whereas a disk-based version might have tuples batched across a few files.
|
||||||
|
This can't reliably be inferred, nor can it be pushed to the
|
||||||
|
``Domain``-level without needing equal levels of specialization (you'd just
|
||||||
|
end up needing the exact same structural distinctions in the ``Domain``
|
||||||
|
hierarchy). So *somewhere* you need a batch splitting implementation that
|
||||||
|
is both item structure-dependent *and* domain-dependent...the question is
|
||||||
|
how dynamic you're willing to be about where it comes from. Right now, we
|
||||||
|
require this actually be defined in the ``_process_batch_data()`` method,
|
||||||
|
meaning you'll need a specific ``Dataset`` class for each domain you want
|
||||||
|
to support (e.g., ``MNISTDisk``, ``MNISTList``, ``MNISTNetwork``, etc), or
|
||||||
|
at least for each domain where "interpreting" a batch could possibly
|
||||||
|
differ. This is a case where the interface is all that enforces a
|
||||||
|
distinction: if you've got two domains that can be counted on to yield
|
||||||
|
batches in the exact same way and can use the same processing, then you
|
||||||
|
could feasibly provide ``Domain`` objects from either at runtime and have
|
||||||
|
no issues. We're "structurally blind" to any differentiation beyond the URI
|
||||||
|
and resource types by design, so two different domain implementations with
|
||||||
|
the same type signature ``Domain[U, R]`` should be expected to work fine at
|
||||||
|
runtime (again, so long as they don't also need different batch
|
||||||
|
processing), but that's not affording us much flexibility, i.e., most of
|
||||||
|
the time we'll still be defining new dataset classes for each domain.
|
||||||
|
|
||||||
|
I initially flagged this as feasible, however, because one could imagine
|
||||||
|
accepting a batch processing method upon instantiation rather than
|
||||||
|
structurally bolting it into the ``Dataset`` definition. This would require
|
||||||
|
knowledge of the item structure ``I`` as well as the ``Domain[U, R]``, so
|
||||||
|
such a function will always have to be (I, U, R)-dependent. It nevertheless
|
||||||
|
would take out some of the pain of having to define new dataset classes;
|
||||||
|
instead, you'd just need to define the batch processing method. I see this
|
||||||
|
as a worse alternative to just defining *inside* a safe context like a new
|
||||||
|
dataset class: you know the types you have to respect, and you stick that
|
||||||
|
method exactly in a context where it's understood. Freeing this up doesn't
|
||||||
|
lighten the burden of processing logic, it just changes *when* it has to be
|
||||||
|
provided, and that's not worth much (to me) in this case given the bump in
|
||||||
|
complexity. (Taking this to the extreme: you could supply *all* of an
|
||||||
|
object's methods "dynamically" and glue them together at runtime so long as
|
||||||
|
they all played nice. But wherever you were "laying them out" beforehand is
|
||||||
|
exactly the job of a class to begin with, so you don't end up with anything
|
||||||
|
more dynamic. All we're really discussing here is pushing around
|
||||||
|
unavoidable complexity inside and outside of the "class walls," and in the
|
||||||
|
particular case of ``_process_batch_data()``, it feels much better when
|
||||||
|
it's on the inside.)
|
||||||
|
|
||||||
|
Holding:
|
||||||
|
@abstractmethod
|
||||||
|
def _get_uri_groups(self) -> Iterable[tuple[U, ...]]:
|
||||||
|
Get URI groups for each batch.
|
||||||
|
|
||||||
|
If there's more than one URI per batch (e.g., a data file and a
|
||||||
|
metadata file), zip the URIs such that we have a tuple of URIs per
|
||||||
|
batch.
|
||||||
|
|
||||||
|
Note that this effectively defines the index style over batches in the
|
||||||
|
attached domain. We get an ``int -> tuple[U, ...]`` map that turns
|
||||||
|
batch indices into URIs that can be read under the domain.
|
||||||
|
``get_batch()`` turns an integer index into its corresponding
|
||||||
|
``tuple[U, ...]``, reading the resources with ``_read_resources()`` in
|
||||||
|
the tuple, treating them as providers of batched data.
|
||||||
|
``_read_resources()`` passes through to the attached domain logic,
|
||||||
|
which, although common, need not supply an explicit iterable of batch
|
||||||
|
items: we just access items with ``__getitem__()`` and may ask for
|
||||||
|
``__len__``. So the returned URI group collection (this method) does
|
||||||
|
need to be iterable to measure the number of batches, but the batch
|
||||||
|
objects that are ultimately produced by these URI groups need not be
|
||||||
|
iterables themselves.
|
||||||
|
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _read_resources(
|
||||||
|
self,
|
||||||
|
uri_group: tuple[U, ...],
|
||||||
|
batch_index: int
|
||||||
|
) -> tuple[R, ...]:
|
||||||
|
Read batch files at the provided paths.
|
||||||
|
|
||||||
|
This method should operate on a single tuple from the list of batch
|
||||||
|
tuples returned by the ``_get_uri_groups()`` method. That is, it reads
|
||||||
|
all of the resources for a single batch and returns a tuple of the same
|
||||||
|
size with their contents.
|
||||||
|
|
||||||
|
Note: the dependence on a batch index is mostly here to make
|
||||||
|
multi-dataset composition easier later. In-dataset, you don't need to
|
||||||
|
know the batch index to simply process URIs, but across datasets you
|
||||||
|
need it to find out the origin of the batch (and process those URIs
|
||||||
|
accordingly).
|
||||||
|
|
||||||
|
return tuple(self.domain.read(uri) for uri in uri_group)
|
||||||
|
|
||||||
|
# pulling the type variable out of the inline generic b/c `ty` has trouble
|
||||||
|
# understanding bound type variables in subclasses (specifically with Self@)
|
||||||
|
T = TypeVar("T", bound=NamedTuple)
|
||||||
|
|
||||||
|
class NamedTupleDataset[I](Dataset):
|
||||||
|
def __init__(self, data_list: list[I]) -> None:
|
||||||
|
self.data_list = data_list
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.data_list)
|
||||||
|
|
||||||
|
def __getitem__(self, index: int) -> I:
|
||||||
|
return self.data_list[index]
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
import logging
|
||||||
|
from abc import abstractmethod
|
||||||
|
from copy import copy
|
||||||
|
from bisect import bisect
|
||||||
|
from typing import Unpack, TypedDict
|
||||||
|
from functools import lru_cache
|
||||||
|
from collections import defaultdict
|
||||||
|
from collections.abc import Callable, Iterator
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
from trainlib.utils import job
|
||||||
|
from trainlib.domain import Domain, SequenceDomain
|
||||||
|
from trainlib.transform import Transform
|
||||||
|
|
||||||
|
logger: logging.Logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetKwargs[I](TypedDict, total=False):
|
||||||
|
pre_transform: Transform[I]
|
||||||
|
post_transform: Transform[I]
|
||||||
|
batch_cache_limit: int
|
||||||
|
preload: bool
|
||||||
|
num_workers: int
|
||||||
|
|
||||||
|
|
||||||
|
class BatchedDataset[U, R, I](Dataset):
|
||||||
|
"""
|
||||||
|
Generic dataset that dynamically pulls batched data from resources (e.g.,
|
||||||
|
files, remote locations, streams, etc).
|
||||||
|
|
||||||
|
The class is generic over a URI type ``U``, a resource type ``R`` (both of
|
||||||
|
which are used to concretize a domain ``Domain[U, R]``), and an item type
|
||||||
|
``I``.
|
||||||
|
|
||||||
|
Pipeline overview:
|
||||||
|
|
||||||
|
```
|
||||||
|
Domain -> [U] (get _batch_uris)
|
||||||
|
U -> R (domain access ; Rs provide batches)
|
||||||
|
R -> [I] (cache here ; _process_batch_data to use load_transform)
|
||||||
|
[I] -> I (human item obj ; _get_item)
|
||||||
|
I -> **P (final packed item ; __getitem__ to use transform)
|
||||||
|
```
|
||||||
|
|
||||||
|
Note^1: as far as positioning, this class is meant to play nice with
|
||||||
|
PyTorch DataLoaders, hence the inheritance from ``torch.Dataset``. The
|
||||||
|
value add for this over the ``torch.Dataset`` base is almost entirely in
|
||||||
|
the logic it implements to map out of *batched resources* that are holding
|
||||||
|
data, and flattening it out into typical dataset items. There are also some
|
||||||
|
QoL items when it comes to splitting and balancing samples.
|
||||||
|
|
||||||
|
Note^2: even though ``Domains`` implement iterators over their URIs, this
|
||||||
|
doesn't imply a ``BatchedDataset`` is iterable. This just means we can walk
|
||||||
|
over the resources that provide data, but we don't necessarily presuppose
|
||||||
|
an ordered walk over samples within batches. Point being:
|
||||||
|
``torch.Dataset``, not ``torch.IterableDataset``, is the appropriate
|
||||||
|
superclass, even when we're working around iterable ``Domains``.
|
||||||
|
|
||||||
|
Note^3: transforms are expected to operate on ``I``-items and produce
|
||||||
|
``I``-items. They shouldn't be the "introducers" of ``I`` types from some
|
||||||
|
other intermediate representation, nor should they map from ``I`` to
|
||||||
|
something else. Point being: the dataset definition should be able to map
|
||||||
|
resources ``R`` to ``I`` without a transform: that much should be baked
|
||||||
|
into the class definition. If you find you're expecting the transform to do
|
||||||
|
that for you, you should consider pulling in some common structure across
|
||||||
|
the allowed transforms and make it a fixed part of the class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
domain: Domain[U, R],
|
||||||
|
pre_transform: Transform[I] | None = None,
|
||||||
|
post_transform: Transform[I] | None = None,
|
||||||
|
batch_cache_limit: int | None = None,
|
||||||
|
preload: bool = False,
|
||||||
|
num_workers: int = 1,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Parameters:
|
||||||
|
pre_transform: transform to apply over items during loading (in
|
||||||
|
``_process_batch_data()``), i.e., *before* going into
|
||||||
|
persistent storage
|
||||||
|
post_transform: transform to apply just prior to returning an item
|
||||||
|
(in ``_process_item_data()``), i.e., only *after* retrieval
|
||||||
|
from persistent storage
|
||||||
|
batch_cache_limit: the max number of batches to cache at any
|
||||||
|
one time
|
||||||
|
preload: whether to load all data into memory during instantiation
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.domain = domain
|
||||||
|
self.pre_transform = pre_transform
|
||||||
|
self.post_transform = post_transform
|
||||||
|
self.batch_cache_limit = batch_cache_limit
|
||||||
|
self.num_workers = num_workers
|
||||||
|
|
||||||
|
logger.info("Fetching URIs...")
|
||||||
|
self._batch_uris: list[U] = list(domain)
|
||||||
|
|
||||||
|
self._indices: list[int] | None = None
|
||||||
|
self._dataset_len: int | None = None
|
||||||
|
self._num_batches: int = len(domain)
|
||||||
|
|
||||||
|
self.get_batch: Callable[[int], list[I]] = lru_cache(
|
||||||
|
maxsize=batch_cache_limit
|
||||||
|
)(self._get_batch)
|
||||||
|
|
||||||
|
if preload:
|
||||||
|
self.load_all(num_workers=num_workers)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _get_dataset_len(self) -> int:
|
||||||
|
"""
|
||||||
|
Calculate the total dataset size in units of samples (not batches).
|
||||||
|
"""
|
||||||
|
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _get_batch_for_item(self, item_index: int) -> tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Return the index of the batch containing the item at the provided item
|
||||||
|
index, and the index of the item within that batch.
|
||||||
|
|
||||||
|
The behavior of this method can vary depending on what we know about
|
||||||
|
batch sizes, and should therefore be implemented by inheriting classes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
batch_index: int
|
||||||
|
index_in_batch: int
|
||||||
|
"""
|
||||||
|
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _process_batch_data(
|
||||||
|
self,
|
||||||
|
batch_data: R,
|
||||||
|
batch_index: int,
|
||||||
|
) -> list[I]:
|
||||||
|
"""
|
||||||
|
Process raw domain resource data (e.g., parse JSON or load tensors) and
|
||||||
|
split accordingly, such that the returned batch is a collection of
|
||||||
|
``I`` items.
|
||||||
|
|
||||||
|
If an inheriting class wants to allow dynamic transforms, this is the
|
||||||
|
place to use a provided ``pre_transform``; the collection of items
|
||||||
|
produced by this method are cached as a batch, so results from such a
|
||||||
|
transform will be stored.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
batch_data: tuple of resource data
|
||||||
|
batch_index: index of batch
|
||||||
|
"""
|
||||||
|
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _process_item_data(
|
||||||
|
self,
|
||||||
|
item_data: I,
|
||||||
|
item_index: int,
|
||||||
|
) -> I:
|
||||||
|
"""
|
||||||
|
Process individual items and produce final tuples.
|
||||||
|
|
||||||
|
If an inheriting class wants to allow dynamic transforms, this is the
|
||||||
|
place to use a provided ``post_transform``; items are pulled from the
|
||||||
|
cache (if enabled) and processed before being returned as the final
|
||||||
|
tuple outputs (so this processing is not persistent).
|
||||||
|
"""
|
||||||
|
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _get_item(self, item_index: int) -> I:
|
||||||
|
"""
|
||||||
|
Get the item data and zip with the item header.
|
||||||
|
|
||||||
|
Items should be the most granular representation of dataset samples
|
||||||
|
with maximum detail (i.e., yield all available information). An
|
||||||
|
iterator over these representations across all samples can be retrieved
|
||||||
|
with `.items()`.
|
||||||
|
|
||||||
|
Note that return values from `__getitem__()` are "cleaned up" versions
|
||||||
|
of this representation, with minimal info needed for training.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if item_index >= len(self):
|
||||||
|
raise IndexError
|
||||||
|
|
||||||
|
# alt indices redefine index count
|
||||||
|
item_index = self.indices[item_index]
|
||||||
|
batch_index, index_in_batch = self._get_batch_for_item(item_index)
|
||||||
|
|
||||||
|
return self.get_batch(batch_index)[index_in_batch]
|
||||||
|
|
||||||
|
def _get_batch(self, batch_index: int) -> list[I]:
|
||||||
|
"""
|
||||||
|
Return the batch data for the provided index.
|
||||||
|
|
||||||
|
Note that we require a list return type. This is where the rubber meets
|
||||||
|
the road in terms of expanding batches: the outputs here get cached, if
|
||||||
|
caching is enabled. If we were to defer batch expansion to some later
|
||||||
|
stage, that caching will be more or less worthless. For instance, if
|
||||||
|
``._process_batch_data()`` was an iterator, at best the batch
|
||||||
|
processing logic could be delayed, but then you'd either 1) cache the
|
||||||
|
iterator reference, or 2) have to further delay caching until
|
||||||
|
post-batch processing. Further, ``._process_batch_data()`` can (and
|
||||||
|
often does) depend on the entire batch, so you can't handle that
|
||||||
|
item-wise: you've got to pass all batch data in and can't act on
|
||||||
|
slices, so doing this would be irrelevant anyway.
|
||||||
|
|
||||||
|
How about iterators out from ``_read_resources()``? Then your
|
||||||
|
``_process_batch_data()`` can iterate as needed and do the work later?
|
||||||
|
The point here is that this distinction is irrelevant because you're
|
||||||
|
reading resources and processing the data here in the same method:
|
||||||
|
they're always connected, and nothing would notice if you waited
|
||||||
|
between steps. The only way this could matter is if you split the
|
||||||
|
resource reading and batch processing steps across methods, but when it
|
||||||
|
actually comes to accessing/caching the batch, you'd have to expand
|
||||||
|
any delayed reads here. There's no way around needing to see all batch
|
||||||
|
data at once here, and we don't want to make that ambiguous: ``list``
|
||||||
|
output type it is.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logger.debug("Batch cache miss, reading from root...")
|
||||||
|
|
||||||
|
if batch_index >= self._num_batches:
|
||||||
|
raise IndexError
|
||||||
|
|
||||||
|
batch_uri = self._batch_uris[batch_index]
|
||||||
|
batch_data = self.domain[batch_uri]
|
||||||
|
|
||||||
|
return self._process_batch_data(batch_data, batch_index)
|
||||||
|
|
||||||
|
def load_all(self, num_workers: int | None = None) -> list[list[I]]:
|
||||||
|
"""
|
||||||
|
Preload all data batches into the cache.
|
||||||
|
|
||||||
|
Can be useful when dynamically pulling data (as it's requested) isn't
|
||||||
|
desired. Requires that `batch_cache_limit=None`, i.e., the cache won't
|
||||||
|
continually remove previous batches as they're loaded.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert self.batch_cache_limit is None, "Preloading under cache limit"
|
||||||
|
|
||||||
|
if num_workers is None:
|
||||||
|
num_workers = self.num_workers
|
||||||
|
|
||||||
|
thread_pool = ThreadPoolExecutor(max_workers=num_workers)
|
||||||
|
|
||||||
|
futures = []
|
||||||
|
for batch_index in range(self._num_batches):
|
||||||
|
future = thread_pool.submit(
|
||||||
|
self.get_batch,
|
||||||
|
batch_index,
|
||||||
|
)
|
||||||
|
futures.append(future)
|
||||||
|
|
||||||
|
job.process_futures(futures, "Loading dataset batches", "batch")
|
||||||
|
thread_pool.shutdown(wait=True)
|
||||||
|
|
||||||
|
return [future.result() for future in futures]
|
||||||
|
|
||||||
|
def split(
|
||||||
|
self,
|
||||||
|
fracs: list[float],
|
||||||
|
dataset: "BatchedDataset | None" = None,
|
||||||
|
by_attr: str | list[str | None] | None = None,
|
||||||
|
shuffle_strata: bool = True,
|
||||||
|
) -> list["BatchedDataset"]:
|
||||||
|
"""
|
||||||
|
Split dataset into fractional pieces by data attribute.
|
||||||
|
|
||||||
|
If `by_attr` is None, recovers typical fractional splitting of dataset
|
||||||
|
items, partitioning by size. Using None anywhere will index each item
|
||||||
|
into its own bucket, i.e., by its index. For instance,
|
||||||
|
|
||||||
|
- by_attr=["color"] -> {("red", 1), ("red", 2)},
|
||||||
|
{("blue", 1), ("blue", 2)}
|
||||||
|
|
||||||
|
Splits on the attribute such that each subset contains entire strata
|
||||||
|
of the attribute. "Homogeneity within clusters"
|
||||||
|
|
||||||
|
- `by_attr=["color", None]` -> {("red", 1), ("blue", 1)},
|
||||||
|
{("red", 2), ("blue", 2)}
|
||||||
|
|
||||||
|
Stratifies by attribute and then splits "by index" within, uniformly
|
||||||
|
grabbing samples across strata to form new clusters. "Homogeneity
|
||||||
|
across clusters"
|
||||||
|
|
||||||
|
Note that the final list of Subsets returned are built from shallow
|
||||||
|
copies of the underlying dataset (i.e., `self`) to allow manual
|
||||||
|
intervention with dataset attributes (e.g., setting the splits to have
|
||||||
|
different `transform`s). This is subject to possibly unexpected
|
||||||
|
behavior if re-caching data or you need a true copy of all data in
|
||||||
|
memory, but should otherwise leave most interactions unchanged.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
shuffle_strata: shuffle the strata order before split is drawn. We
|
||||||
|
parameterize this because a dataloader-level shuffle operation
|
||||||
|
will only change the order of the indices in the resulting
|
||||||
|
splits; only a shuffle of items inside the strata can change
|
||||||
|
the actual content of the splits themselves.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if by_attr == []:
|
||||||
|
raise ValueError("Cannot parse empty value list")
|
||||||
|
|
||||||
|
assert (
|
||||||
|
math.isclose(sum(fracs), 1) and sum(fracs) <= 1
|
||||||
|
), "Fractions do not sum to 1"
|
||||||
|
|
||||||
|
if isinstance(by_attr, str) or by_attr is None:
|
||||||
|
by_attr = [by_attr]
|
||||||
|
|
||||||
|
if dataset is None:
|
||||||
|
dataset = self
|
||||||
|
# dataset = DictDataset([
|
||||||
|
# self._get_item(i) for i in range(len(self))
|
||||||
|
# ])
|
||||||
|
|
||||||
|
# group samples by specified attr
|
||||||
|
attr_dict = defaultdict(list)
|
||||||
|
attr_key, by_attr = by_attr[0], by_attr[1:]
|
||||||
|
# for i in range(len(dataset)):
|
||||||
|
for i, item in enumerate(dataset.items()):
|
||||||
|
# item = dataset[i]
|
||||||
|
if attr_key is None:
|
||||||
|
attr_val = i
|
||||||
|
elif attr_key in item:
|
||||||
|
attr_val = item[attr_key]
|
||||||
|
else:
|
||||||
|
raise IndexError(f"Attribute {attr_key} not in dataset item")
|
||||||
|
attr_dict[attr_val].append(i)
|
||||||
|
|
||||||
|
if by_attr == []:
|
||||||
|
attr_keys = list(attr_dict.keys())
|
||||||
|
|
||||||
|
# shuffle keys; randomized group-level split
|
||||||
|
if shuffle_strata:
|
||||||
|
random.shuffle(attr_keys)
|
||||||
|
|
||||||
|
# considering: defer to dataloader shuffle param; should have same
|
||||||
|
# effect shuffle values; has no impact on where the split is drawn
|
||||||
|
# for attr_vals in attr_dict.values():
|
||||||
|
# random.shuffle(attr_vals)
|
||||||
|
|
||||||
|
# fractionally split over attribute keys
|
||||||
|
offset, splits = 0, []
|
||||||
|
for frac in fracs[:-1]:
|
||||||
|
frac_indices = []
|
||||||
|
frac_size = int(frac * len(attr_keys))
|
||||||
|
for j in range(offset, offset + frac_size):
|
||||||
|
j_indices = attr_dict.pop(attr_keys[j])
|
||||||
|
frac_indices.extend(j_indices)
|
||||||
|
offset += frac_size
|
||||||
|
splits.append(frac_indices)
|
||||||
|
|
||||||
|
rem_indices = []
|
||||||
|
for r_indices in attr_dict.values():
|
||||||
|
rem_indices.extend(r_indices)
|
||||||
|
splits.append(rem_indices)
|
||||||
|
|
||||||
|
subsets = []
|
||||||
|
for split in splits:
|
||||||
|
subset = copy(dataset)
|
||||||
|
subset.indices = split
|
||||||
|
subsets.append(subset)
|
||||||
|
|
||||||
|
return subsets
|
||||||
|
else:
|
||||||
|
splits = [[] for _ in range(len(fracs))]
|
||||||
|
for index_split in attr_dict.values():
|
||||||
|
# subset = Subset(dataset, index_split)
|
||||||
|
subset = copy(dataset)
|
||||||
|
subset.indices = index_split
|
||||||
|
subset_splits = self.split(
|
||||||
|
fracs, subset, by_attr, shuffle_strata
|
||||||
|
)
|
||||||
|
|
||||||
|
# unpack stratified splits
|
||||||
|
for i, subset_split in enumerate(subset_splits):
|
||||||
|
# splits[i].extend([
|
||||||
|
# index_split[s] for s in subset_split.indices
|
||||||
|
# ])
|
||||||
|
splits[i].extend(subset_split.indices)
|
||||||
|
|
||||||
|
subsets = []
|
||||||
|
for split in splits:
|
||||||
|
subset = copy(dataset)
|
||||||
|
subset.reset_indices()
|
||||||
|
subset.indices = split
|
||||||
|
subsets.append(subset)
|
||||||
|
|
||||||
|
return subsets
|
||||||
|
|
||||||
|
# considering: defer to dataloader shuffle param; should have same
|
||||||
|
# effect shuffle each split after merging; may otherwise be homogeneous
|
||||||
|
# for split in splits:
|
||||||
|
# random.shuffle(split)
|
||||||
|
|
||||||
|
# return [Subset(copy(self), split) for split in splits]
|
||||||
|
|
||||||
|
def balance(
    self,
    dataset: "BatchedSubset[U, R, I] | None" = None,
    by_attr: str | list[str | None] | None = None,
    split_min_sizes: list[int] | None = None,
    split_max_sizes: list[int] | None = None,
    shuffle_strata: bool = True,
) -> None:
    """
    Balance the dataset in place.

    Computes a balanced index selection via ``_balance()`` and assigns it
    to ``self.indices``, masking the dataset down to that selection.
    """

    balanced_indices = self._balance(
        dataset=dataset,
        by_attr=by_attr,
        split_min_sizes=split_min_sizes,
        split_max_sizes=split_max_sizes,
        shuffle_strata=shuffle_strata,
    )
    self.indices = balanced_indices
|
||||||
|
|
||||||
|
def _balance(
    self,
    dataset: "BatchedSubset[U, R, I] | None" = None,
    by_attr: str | list[str | None] | None = None,
    split_min_sizes: list[int] | None = None,
    split_max_sizes: list[int] | None = None,
    shuffle_strata: bool = True,
) -> list[int]:
    """
    Compute a balanced list of item indices, optionally stratified.

    Items are grouped by the first attribute name in ``by_attr``; any
    remaining names are handled by recursing into each group. Groups
    smaller than the current level's minimum size are dropped, and all
    surviving groups are truncated to a common size (the smallest
    surviving group size, capped by the current level's maximum size).

    Parameters:
        dataset: dataset to balance; defaults to a subset view over self
        by_attr: attribute name(s) to group by; a ``None`` entry groups
            each item by its own index (i.e., no grouping at that level)
        split_min_sizes: per-level minimum group sizes
        split_max_sizes: per-level maximum group sizes
        shuffle_strata: when True, shuffle group order before truncation

    Returns:
        list of (relative) item indices composing the balanced selection

    Raises:
        ValueError: if ``by_attr`` is an empty list
        IndexError: if a requested attribute is missing from an item

    Note: behavior is a little odd for nested behavior; not exactly
    perfectly uniform throughout. This is a little difficult: you can't
    exactly know ahead of time the size of the subgroups across splits
    """

    if by_attr == []:
        raise ValueError("Cannot parse empty value list")

    if isinstance(by_attr, str) or by_attr is None:
        by_attr = [by_attr]

    if dataset is None:
        dataset = BatchedSubset(self, self.indices)

    if split_min_sizes == [] or split_min_sizes is None:
        split_min_sizes = [0]
    if split_max_sizes == [] or split_max_sizes is None:
        split_max_sizes = [len(dataset)]

    # group samples by specified attr
    attr_dict = defaultdict(list)
    attr_key, by_attr = by_attr[0], by_attr[1:]
    for i in range(len(dataset)):
        item = dataset[i]
        if attr_key is None:
            attr_val = i
        elif hasattr(item, attr_key):
            # fix: use hasattr() to mirror the getattr() access below;
            # ``attr_key in item`` tested *container membership* (e.g.,
            # a named tuple's values), not attribute presence
            attr_val = getattr(item, attr_key)
        else:
            raise IndexError(f"Attribute {attr_key} not in dataset item")
        attr_dict[attr_val].append(i)

    subset_splits = []
    split_min_size, split_min_sizes = (
        split_min_sizes[0],
        split_min_sizes[1:]
    )
    split_max_size, split_max_sizes = (
        split_max_sizes[0],
        split_max_sizes[1:]
    )
    for split_indices in attr_dict.values():
        if by_attr != []:
            # recurse into the group over the remaining attributes, then
            # map the recursive (group-relative) indices back into this
            # level's index space
            subset_indices = self._balance(
                BatchedSubset(dataset, split_indices),
                by_attr,
                split_min_sizes,
                split_max_sizes,
                shuffle_strata,
            )
            split_indices = [split_indices[s] for s in subset_indices]
        subset_splits.append(split_indices)

    # shuffle splits; randomized group-level split
    if shuffle_strata:
        random.shuffle(subset_splits)

    # note: split_min_size is smallest allowed, min_split_size is smallest
    # observed
    valid_splits = [
        ss for ss in subset_splits
        if len(ss) >= split_min_size
    ]

    # guard: if every group falls below the minimum size there is nothing
    # to balance (min() over an empty sequence would raise ValueError)
    if not valid_splits:
        return []

    min_split_size = min(len(split) for split in valid_splits)

    subset_indices = []
    for split_indices in valid_splits:
        # truncate each surviving group to the common (capped) size
        subset_indices.extend(
            split_indices[: min(min_split_size, split_max_size)]
        )

    return subset_indices
|
||||||
|
|
||||||
|
@property
def indices(self) -> list[int]:
    """
    Active item indices, lazily initialized to the identity mapping
    ``0..len(self)-1`` on first access.
    """

    cached = self._indices
    if cached is None:
        cached = list(range(len(self)))
        self._indices = cached
    return cached
|
||||||
|
|
||||||
|
@indices.setter
def indices(self, indices: list[int]) -> None:
    """
    Note: this logic facilitates nested re-indexing over the same base
    dataset. The underlying data remain the same, but when indices get set,
    you're effectively applying a mask over any existing indices, always
    operating *relative* to the existing mask.
    """

    # manually set new size
    # NOTE(review): the length is updated *before* the getter read below;
    # when ``_indices`` is still unset, the getter therefore builds its
    # identity list at the *new* length, so incoming indices must all be
    # < len(indices) in that case — confirm this is the intended contract
    self._dataset_len = len(indices)

    # note: this is a little tricky and compact; follow what happens when
    # _indices aren't already set
    # Each new index is translated through the *current* mask, composing
    # the masks rather than replacing them.
    self._indices = [
        self.indices[index]
        for index in indices
    ]
|
||||||
|
|
||||||
|
def reset_indices(self) -> None:
    """Drop the index mask and cached length, restoring the full dataset view."""
    self._indices = self._dataset_len = None
|
||||||
|
|
||||||
|
def items(self) -> Iterator[I]:
    """Yield every dataset item in index order."""
    yield from map(self._get_item, range(len(self)))
|
||||||
|
|
||||||
|
def __len__(self) -> int:
    """Dataset size, computed once via ``_get_dataset_len()`` and cached."""
    size = self._dataset_len
    if size is None:
        size = self._get_dataset_len()
        self._dataset_len = size
    return size
|
||||||
|
|
||||||
|
def __getitem__(self, index: int) -> I:
    # fetch the raw item data for the (mask-relative) index
    item_data = self._get_item(index)
    # translate the requested index through the active index mask before
    # handing it to the processing hook
    # NOTE(review): ``_get_item`` implementations seen elsewhere in this
    # file also map through ``self.indices`` — confirm the two mappings
    # don't double-apply
    index = self.indices[index]

    return self._process_item_data(item_data, index)
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[I]:
    """
    Iterate items in index order.

    Note: this method isn't technically needed given ``__getitem__`` is
    defined and we operate cleanly over integer indices 0..(N-1), so even
    without an explicit ``__iter__``, Python will fall back to a reliable
    iteration mechanism. We nevertheless implement the trivial logic below
    to convey intent and meet static type checks for iterables.
    """

    return map(self.__getitem__, range(len(self)))
|
||||||
|
|
||||||
|
|
||||||
|
class CompositeBatchedDataset[U, R, I](BatchedDataset[U, R, I]):
|
||||||
|
"""
|
||||||
|
Dataset class for wrapping individual datasets.
|
||||||
|
|
||||||
|
Note: because this remains a valid ``BatchedDataset``, we re-thread the
|
||||||
|
generic type variables through the set of composed datasets. That is, they
|
||||||
|
must have a common domain type ``Domain[U, R]``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
datasets: list[BatchedDataset[U, R, I]],
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Parameters:
|
||||||
|
datasets: list of datasets
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.datasets = datasets
|
||||||
|
|
||||||
|
self._indices: list[int] | None = None
|
||||||
|
self._dataset_len: int | None = None
|
||||||
|
|
||||||
|
self._item_psum: list[int] = []
|
||||||
|
self._batch_psum: list[int] = []
|
||||||
|
|
||||||
|
def _compute_prefix_sum(self, arr: list[int]) -> list[int]:
|
||||||
|
if not arr:
|
||||||
|
return []
|
||||||
|
|
||||||
|
prefix_sum = [0] * (len(arr) + 1)
|
||||||
|
for i in range(len(arr)):
|
||||||
|
prefix_sum[i + 1] = prefix_sum[i] + arr[i]
|
||||||
|
|
||||||
|
return prefix_sum
|
||||||
|
|
||||||
|
def _get_dataset_for_item(self, item_index: int) -> tuple[int, int]:
|
||||||
|
dataset_index = bisect(self._item_psum, item_index) - 1
|
||||||
|
index_in_dataset = item_index - self._item_psum[dataset_index]
|
||||||
|
|
||||||
|
return dataset_index, index_in_dataset
|
||||||
|
|
||||||
|
def _get_dataset_for_batch(self, batch_index: int) -> tuple[int, int]:
|
||||||
|
dataset_index = bisect(self._batch_psum, batch_index) - 1
|
||||||
|
index_in_dataset = batch_index - self._batch_psum[dataset_index]
|
||||||
|
|
||||||
|
return dataset_index, index_in_dataset
|
||||||
|
|
||||||
|
def _get_batch_for_item(self, item_index: int) -> tuple[int, int]:
|
||||||
|
index_pair = self._get_dataset_for_item(item_index)
|
||||||
|
dataset_index, index_in_dataset = index_pair
|
||||||
|
dataset = self.datasets[dataset_index]
|
||||||
|
|
||||||
|
return dataset._get_batch_for_item(index_in_dataset)
|
||||||
|
|
||||||
|
def _get_dataset_len(self) -> int:
|
||||||
|
self.load_all()
|
||||||
|
|
||||||
|
dataset_batch_sizes = [len(dataset) for dataset in self.datasets]
|
||||||
|
dataset_sizes = [dataset._num_batches for dataset in self.datasets]
|
||||||
|
|
||||||
|
# this method will only be ran once; set this instance var
|
||||||
|
self._item_psum = self._compute_prefix_sum(dataset_batch_sizes)
|
||||||
|
self._batch_psum = self._compute_prefix_sum(dataset_sizes)
|
||||||
|
|
||||||
|
return self._item_psum[-1]
|
||||||
|
|
||||||
|
def _process_batch_data(
|
||||||
|
self,
|
||||||
|
batch_data: R,
|
||||||
|
batch_index: int,
|
||||||
|
) -> list[I]:
|
||||||
|
index_pair = self._get_dataset_for_item(batch_index)
|
||||||
|
dataset_index, index_in_dataset = index_pair
|
||||||
|
dataset = self.datasets[dataset_index]
|
||||||
|
|
||||||
|
return dataset._process_batch_data(batch_data, index_in_dataset)
|
||||||
|
|
||||||
|
def _process_item_data(self, item_data: I, item_index: int) -> I:
|
||||||
|
index_pair = self._get_dataset_for_item(item_index)
|
||||||
|
dataset_index, index_in_dataset = index_pair
|
||||||
|
dataset = self.datasets[dataset_index]
|
||||||
|
|
||||||
|
return dataset._process_item_data(item_data, index_in_dataset)
|
||||||
|
|
||||||
|
def _get_item(self, item_index: int) -> I:
|
||||||
|
item_index = self.indices[item_index]
|
||||||
|
|
||||||
|
index_pair = self._get_dataset_for_item(item_index)
|
||||||
|
dataset_index, index_in_dataset = index_pair
|
||||||
|
dataset = self.datasets[dataset_index]
|
||||||
|
|
||||||
|
return dataset._get_item(index_in_dataset)
|
||||||
|
|
||||||
|
def _get_batch(self, batch_index: int) -> list[I]:
|
||||||
|
index_pair = self._get_dataset_for_item(batch_index)
|
||||||
|
dataset_index, index_in_dataset = index_pair
|
||||||
|
dataset = self.datasets[dataset_index]
|
||||||
|
|
||||||
|
return dataset._get_batch(index_in_dataset)
|
||||||
|
|
||||||
|
def load_all(
|
||||||
|
self,
|
||||||
|
num_workers: int | None = None
|
||||||
|
) -> list[list[I]]:
|
||||||
|
batches = []
|
||||||
|
for dataset in self.datasets:
|
||||||
|
batches.extend(dataset.load_all(num_workers))
|
||||||
|
return batches
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pre_transform(self) -> list[Transform[I] | None]:
|
||||||
|
return [dataset.pre_transform for dataset in self.datasets]
|
||||||
|
|
||||||
|
@pre_transform.setter
|
||||||
|
def pre_transform(self, pre_transform: Transform[I]) -> None:
|
||||||
|
for dataset in self.datasets:
|
||||||
|
dataset.pre_transform = pre_transform
|
||||||
|
|
||||||
|
@property
|
||||||
|
def post_transform(self) -> list[Transform[I] | None]:
|
||||||
|
return [dataset.post_transform for dataset in self.datasets]
|
||||||
|
|
||||||
|
@post_transform.setter
|
||||||
|
def post_transform(self, post_transform: Transform[I]) -> None:
|
||||||
|
for dataset in self.datasets:
|
||||||
|
dataset.post_transform = post_transform
|
||||||
|
|
||||||
|
|
||||||
|
class BatchedSubset[U, R, I](BatchedDataset[U, R, I]):
    """
    Index-masked view over an existing ``BatchedDataset``.
    """

    def __init__(
        self,
        dataset: BatchedDataset[U, R, I],
        indices: list[int],
    ) -> None:
        # underlying dataset being viewed
        self.dataset = dataset
        # NOTE(review): the parent ``__init__`` is not called; this
        # assignment presumably goes through the inherited ``indices``
        # property setter, which reads ``self._indices`` — confirm the
        # base class provides a default (e.g., class-level ``_indices``)
        self.indices = indices

    def _get_item(self, item_index: int) -> I:
        """
        Subset indices are "reset" in its context. Simply passes through
        to the wrapped dataset after translating through this view's mask.
        """

        return self.dataset._get_item(self.indices[item_index])

    def _get_dataset_len(self) -> int:
        # the view's size is just the number of masked-in indices
        return len(self.indices)
|
||||||
|
|
||||||
|
|
||||||
|
class HomogenousDataset[U, R, I](BatchedDataset[U, R, I]):
|
||||||
|
"""
|
||||||
|
Batched dataset where batches are equally sized.
|
||||||
|
|
||||||
|
Subclass from this base when you can count on the reference data being
|
||||||
|
prepared with fixed size batches (up to the last batch), e.g., that which
|
||||||
|
has been prepared with a `Packer`. This can greatly improve measurement
|
||||||
|
time of the dataset size by preventing the need for reading all batch files
|
||||||
|
upfront, and reduces the cost of identifying item batches from O(log n) to
|
||||||
|
O(1).
|
||||||
|
|
||||||
|
Methods left for inheriting classes:
|
||||||
|
|
||||||
|
- ``_process_item_data()``: item processing
|
||||||
|
- ``_process_batch_data()``: batch processing
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
domain: Domain[U, R],
|
||||||
|
**kwargs: Unpack[DatasetKwargs],
|
||||||
|
) -> None:
|
||||||
|
super().__init__(domain, **kwargs)
|
||||||
|
|
||||||
|
# determine batch size across dataset, along w/ possible partial final
|
||||||
|
# batch
|
||||||
|
bsize = rem = len(self.get_batch(self._num_batches - 1))
|
||||||
|
if self._num_batches > 1:
|
||||||
|
bsize, rem = len(self.get_batch(self._num_batches - 2)), bsize
|
||||||
|
|
||||||
|
self._batch_size: int = bsize
|
||||||
|
self._batch_rem: int = rem
|
||||||
|
|
||||||
|
def _get_dataset_len(self) -> int:
|
||||||
|
return self._batch_size * (self._num_batches - 1) + self._batch_rem
|
||||||
|
|
||||||
|
def _get_batch_for_item(self, item_index: int) -> tuple[int, int]:
|
||||||
|
return item_index // self._batch_size, item_index % self._batch_size
|
||||||
|
|
||||||
|
|
||||||
|
class HeterogenousDataset[U, R, I](BatchedDataset[U, R, I]):
|
||||||
|
"""
|
||||||
|
Batched dataset where batches have arbitrary size.
|
||||||
|
|
||||||
|
Methods left for inheriting classes:
|
||||||
|
|
||||||
|
- ``_process_item_data()``: item processing
|
||||||
|
- ``_process_batch_data()``: batch processing
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
domain: Domain[U, R],
|
||||||
|
**kwargs: Unpack[DatasetKwargs],
|
||||||
|
) -> None:
|
||||||
|
super().__init__(domain, **kwargs)
|
||||||
|
|
||||||
|
self._batch_size_psum: list[int] = []
|
||||||
|
|
||||||
|
def _compute_prefix_sum(self, arr: list[int]) -> list[int]:
|
||||||
|
if not arr:
|
||||||
|
return []
|
||||||
|
|
||||||
|
prefix_sum = [0] * (len(arr) + 1)
|
||||||
|
for i in range(len(arr)):
|
||||||
|
prefix_sum[i + 1] = prefix_sum[i] + arr[i]
|
||||||
|
|
||||||
|
return prefix_sum
|
||||||
|
|
||||||
|
def _get_dataset_len(self) -> int:
|
||||||
|
# type error below: no idea why this is flagged
|
||||||
|
batches = self.load_all()
|
||||||
|
batch_sizes = [len(batch) for batch in batches]
|
||||||
|
|
||||||
|
# this method will only be ran once; set this instance var
|
||||||
|
self._batch_size_psum = self._compute_prefix_sum(batch_sizes)
|
||||||
|
|
||||||
|
return self._batch_size_psum[-1]
|
||||||
|
|
||||||
|
def _get_batch_for_item(self, item_index: int) -> tuple[int, int]:
|
||||||
|
batch_index = bisect(self._batch_size_psum, item_index) - 1
|
||||||
|
index_in_batch = item_index - self._batch_size_psum[batch_index]
|
||||||
|
|
||||||
|
return batch_index, index_in_batch
|
||||||
|
|
||||||
|
|
||||||
|
class SequenceDataset[I](HomogenousDataset[int, I, I]):
|
||||||
|
"""
|
||||||
|
Trivial dataset skeleton for sequence domains.
|
||||||
|
|
||||||
|
``I``-typed sequence items map directly to dataset items. To produce a
|
||||||
|
fully concrete dataset, one still needs to define ``_process_item_data()``
|
||||||
|
to map from ``I``-items to tuples.
|
||||||
|
"""
|
||||||
|
|
||||||
|
domain: SequenceDomain[I]
|
||||||
|
|
||||||
|
def _process_batch_data(
|
||||||
|
self,
|
||||||
|
batch_data: I,
|
||||||
|
batch_index: int,
|
||||||
|
) -> list[I]:
|
||||||
|
if self.pre_transform is not None:
|
||||||
|
batch_data = self.pre_transform(batch_data)
|
||||||
|
|
||||||
|
return [batch_data]
|
||||||
|
|
||||||
|
|
||||||
|
class TupleDataset[T](SequenceDataset[tuple[T, ...]]):
|
||||||
|
"""
|
||||||
|
Trivial sequence-of-tuples dataset.
|
||||||
|
|
||||||
|
This is the most straightforward line to a concrete dataset from a
|
||||||
|
``BatchedDataset`` base class. That is: the underlying domain is a sequence
|
||||||
|
whose items are mapped to single-item batches and are already tuples.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _process_item_data(
|
||||||
|
self,
|
||||||
|
item_data: tuple[T, ...],
|
||||||
|
item_index: int,
|
||||||
|
) -> tuple[T, ...]:
|
||||||
|
if self.post_transform is not None:
|
||||||
|
item_data = self.post_transform(item_data)
|
||||||
|
|
||||||
|
return item_data
|
||||||
0
trainlib/datasets/__init__.py
Normal file
0
trainlib/datasets/__init__.py
Normal file
179
trainlib/datasets/disk.py
Normal file
179
trainlib/datasets/disk.py
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
from io import BytesIO
|
||||||
|
from abc import abstractmethod
|
||||||
|
from typing import Any, NamedTuple
|
||||||
|
from pathlib import Path
|
||||||
|
from zipfile import ZipFile
|
||||||
|
|
||||||
|
from mema.dataset import HomogenousDataset
|
||||||
|
from mema.domains.disk import DiskDomain
|
||||||
|
|
||||||
|
|
||||||
|
class DiskDataset[T: NamedTuple](HomogenousDataset[Path, bytes, T]):
    """
    Dataset whose resources are files on disk (URIs are ``Path``s, raw
    resources are ``bytes``).

    The class-level declaration below is to satisfy the type checker, which

    1. Can't recognize an appropriately re-typed constructor arg like

        def __init__(
            self,
            domain: DiskDomain,
            ...
        ): ...

    This *does* match the parent generic for the U=Path, R=bytes context

        def __init__(
            self,
            domain: Domain[U, R],
            ...
        ): ...

    but the type checker doesn't see this.

    2. "Lifted" type variables out of generics can't be used as upper bounds,
    at least not without throwing type checker warnings (thanks to PEP695).
    So I'm not allowed to have

    ```
    class BatchedDataset[U, R, D: Domain[U, R]]:
        ...
    ```

    which could bring appropriately dynamic typing for ``Domain``s, but is
    not a sufficiently concrete upper bound.

    So: we settle for a class-level type declaration, which despite not being
    technically appropriately scoped, isn't harming anything and satisfies
    ``ty`` type checks downstream (e.g., when we access ``DiskDomain.root``).
    """

    # narrow the inherited ``domain`` attribute for type checkers
    domain: DiskDomain
|
||||||
|
|
||||||
|
|
||||||
|
class PackedDataset(DiskDataset):
    """
    Packed dataset.

    Currently out of commission - not compatible with latest dataset
    definitions. Will require a zipped disk domain

    Requires a specific dataset storage structure on the root data path:

    <data-path>/data/*-i<batch-num>-b<batch-size>
    <data-path>/meta/*-i<batch-num>-b<batch-size>

    That is, all data are compacted into core data (`data/`) and metadata
    (`meta/`) subdirectories. Compatible out-of-the-box with datasets written
    with a `Packer`.
    """

    def _get_uri_groups(self) -> list[tuple[Path, ...]]:
        # pair each data file with its metadata counterpart
        data_root = Path(self.domain.root, "data")
        meta_root = Path(self.domain.root, "meta")

        data_file_paths = data_root.iterdir()
        meta_file_paths = meta_root.iterdir()

        # NOTE(review): ``iterdir()`` order is not guaranteed sorted;
        # pairing data/meta by raw iteration order assumes matching
        # filenames enumerate identically — confirm (or sort both)
        return list(zip(data_file_paths, meta_file_paths, strict=True))

    def _process_batch_data(
        self,
        batch_data: tuple[bytes, ...],
        batch_index: int,
    ) -> list[tuple[bytes, ...]]:
        # a "batch" here is a (data blob, meta blob) pair
        data_bytes, meta_bytes = batch_data

        # metadata is unpacked first since data unpacking may need it
        meta_batch = self._unpack_meta(meta_bytes)
        data_batch = self._unpack_data(data_bytes, meta_batch)

        # zip up batch partial batch items into a single batch iterable
        # composed of item tuples
        batch_items = [
            (*ba, *bm)  # pyre-ignore[60]
            for ba, bm in zip(data_batch, meta_batch, strict=True)
        ]

        # apply transform to batch items if provided
        # NOTE(review): other datasets in this file use
        # ``pre_transform``/``post_transform``; ``load_transform`` appears
        # only here — confirm it still exists on the current base class
        if self.load_transform:
            batch_items = list(map(self.load_transform, batch_items))

        return batch_items

    @abstractmethod
    def _unpack_data(
        self,
        batch_data_bytes: bytes,
        batch_meta: list[tuple[Any, ...]],
    ) -> list[tuple[Any, ...]]:
        """
        Load and unpack batch data.

        This method should be the inverse of an affiliated
        `Packer`'s `pack_data_bytes()`:

        <data list> -> pack -> to bytes -> blob
        <data list> <- unpack <- from bytes <- blob

        Returns:
            iterable of (partial) item tuples w/ data content
        """

        raise NotImplementedError

    @abstractmethod
    def _unpack_meta(self, batch_meta_bytes: bytes) -> list[tuple[Any, ...]]:
        """
        Load and unpack batch metadata.

        This method should be the inverse of an affiliated
        `Packer`'s `pack_meta_bytes()`:

        <data list> -> pack -> to bytes -> blob
        <data list> <- unpack <- from bytes <- blob

        Returns:
            iterable of (partial) item tuples w/ meta content
        """

        raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class ZippedDataset(DiskDataset):
    """
    Dataset with samples stored in ZIP files.

    This dataset base is primarily used as the type of input dataset for a
    `Packer` object. This is compatible with most raw dataset structures,
    reading down arbitrarily packaged ZIP files and "re-batching" during
    access.
    """

    # NOTE(review): items produced below carry two fields (bytes, name)
    # while this header declares one — confirm whether it should be
    # ("bytes", "name")
    item_header: tuple[str, ...] = ("bytes",)

    def _get_uri_groups(self) -> list[tuple[str, ...]]:
        """
        Collect the ZIP files under the domain root, one URI group per file.
        """

        # fix: read from the DiskDomain root (consistent with sibling
        # ``PackedDataset``); ``self.data_path`` is not defined anywhere
        # on this hierarchy
        zip_file_paths: list[str] = [
            str(path)
            for path in Path(self.domain.root).iterdir()
            if path.suffix == ".zip"
        ]

        # will just zip a single list yielding 1-tuples (to match type sig)
        return list(zip(zip_file_paths))

    def _process_batch_data(
        self,
        batch_data: tuple[bytes, ...],
        batch_index: int,
    ) -> list[tuple[bytes, ...]]:
        """
        Unpack one ZIP blob into per-member ``(bytes, name)`` item tuples.

        NOTE(review): ``self.extensions`` is not defined on this class —
        presumably supplied by a subclass; confirm.
        """

        items = []
        batch_zip_file = ZipFile(BytesIO(batch_data[0]))

        for zname in batch_zip_file.namelist():
            # skip ZIP members whose extension isn't whitelisted
            if Path(zname).suffix not in self.extensions:
                continue

            with batch_zip_file.open(zname, "r") as zfile:
                items.append((zfile.read(), zname))

        return items
|
||||||
|
|
||||||
210
trainlib/datasets/memory.py
Normal file
210
trainlib/datasets/memory.py
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
"""
|
||||||
|
Leaving the following here in case we return for some specifics in the future.
|
||||||
|
At an earlier stage, we had a few specifically typed datasets/domains for
|
||||||
|
in-memory structures, but these are almost entirely redundant with the generic
|
||||||
|
``SequenceDomain`` definition: the general retrieval and iteration behaviors
|
||||||
|
there are fairly universal and can be type-tailored without new class
|
||||||
|
definitions (unless you really want a new hierarchy).
|
||||||
|
|
||||||
|
The following were design notes on a dict/named tuple list-based dataset.
|
||||||
|
|
||||||
|
Dataset from list of (dict) records.
|
||||||
|
|
||||||
|
This is designed such that a "batch" is just a single record in the base
|
||||||
|
list. Each URI group is a singleton index tuple, which grabs its single
|
||||||
|
corresponding record in the domain when read.
|
||||||
|
|
||||||
|
One could alternatively have the entire record list be a single batch, and
|
||||||
|
have a single URI group with N null references. The null reference would
|
||||||
|
need to be the sole URI for the domain definition, and ``.read(null)``
|
||||||
|
would always return the entire list. Everything will then be handled as
|
||||||
|
expected: in ``get_item()``,
|
||||||
|
|
||||||
|
- ``._get_batch_for_item(item_index)`` maps to the singleton batch
|
||||||
|
- ``.get_batch(batch_index)[index_in_batch]`` simply indexes directly in
|
||||||
|
the record list
|
||||||
|
|
||||||
|
This is pretty unnatural, since now a batch as returned by
|
||||||
|
``.read_resources()`` will be N references to the entire record list. It
|
||||||
|
nevertheless sidesteps full list reconstruction and allows the propagation
|
||||||
|
of direct indexing, so it has that in its corner.
|
||||||
|
|
||||||
|
Another viable approach (least preferred): have a single batch, but where
|
||||||
|
URIs map to individual row indices. In some respects, this is the most
|
||||||
|
intuitive interpretation of "batch" and how we'd map to items, but it's the
|
||||||
|
*least efficient* because the batch logic of ``BatchedDataset`` will
|
||||||
|
*reconstruct the entire record list*. In ``.read_resources()``, we have
|
||||||
|
|
||||||
|
```
|
||||||
|
return tuple(self.domain.read(uri) for uri in uri_group)
|
||||||
|
```
|
||||||
|
|
||||||
|
So, despite being a natural interpretation, it pulls apart the domain "too
|
||||||
|
well" and rebuilds it in-house. This makes sense when resources are
|
||||||
|
external (e.g., files on disk, data over network), but when already
|
||||||
|
part of the Python process, that model is wasteful.
|
||||||
|
|
||||||
|
Note: inheriting datasets will need to implement an appropriate
|
||||||
|
``_item_header``. This will be trivial for consistent datasets, since it'll
|
||||||
|
just be the keys of any of the dict records.
|
||||||
|
|
||||||
|
Left to define:
|
||||||
|
|
||||||
|
+ ``_process_item_data()``: map T to final tuple
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class RecordDataset[T: NamedTuple](HomogenousDataset[int, T, T]):
|
||||||
|
|
||||||
|
domain: TupleListDomain[T]
|
||||||
|
|
||||||
|
def _process_batch_data(
|
||||||
|
self,
|
||||||
|
batch_data: T,
|
||||||
|
batch_index: int,
|
||||||
|
) -> list[T]:
|
||||||
|
Produce collection of item tuples.
|
||||||
|
|
||||||
|
Note: no interaction with ``item_tuple`` is needed given batches are
|
||||||
|
already singular ``item_tuple`` shaped items.
|
||||||
|
|
||||||
|
return [batch_data]
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Unpack
|
||||||
|
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from trainlib.domain import SequenceDomain
|
||||||
|
from trainlib.dataset import TupleDataset, DatasetKwargs
|
||||||
|
|
||||||
|
|
||||||
|
class SlidingWindowDataset[T: Tensor](TupleDataset[T]):
    """
    Dataset of sliding windows over channelwise sequence tensors.

    Each batch is a tuple of ``(C_i, T)``-shaped tensors: the first
    ``num_windows`` entries are windowed into model inputs, the remaining
    entries become lookahead-aligned label sequences.
    """

    def __init__(
        self,
        domain: SequenceDomain[tuple[T, ...]],
        lookback: int,
        offset: int = 0,
        lookahead: int = 1,
        num_windows: int = 1,
        **kwargs: Unpack[DatasetKwargs],
    ) -> None:
        """
        Parameters:
            domain: sequence domain yielding tuples of (C_i, T) tensors
            lookback: window length; strictly positive
            offset: forward shift of the first window (see
                ``_process_batch_data()``); effectively clamped to
                ``lookback - 1``
            lookahead: label offset from each window's first index
            num_windows: number of leading tuple entries to window; the
                remaining entries are treated as label sequences
            kwargs: forwarded to the ``TupleDataset`` base
        """

        # window parameters are stored before base initialization —
        # presumably so they're available if the base constructor already
        # triggers batch processing (TODO confirm)
        self.lookback = lookback
        self.offset = offset
        self.lookahead = lookahead
        self.num_windows = num_windows

        super().__init__(domain, **kwargs)

    def _process_batch_data(
        self,
        batch_data: tuple[T, ...],
        batch_index: int,
    ) -> list[tuple[T, ...]]:
        """
        Backward pads first sequence over (lookback-1) length, and steps the
        remaining items forward by the lookahead.

        Batch data:

            (Tensor[C1, T], ..., Tensor[CN, T])

        + lookback determines window size; pad left to create lookback size
          with the first element at the right:

              |-lookback-|
              [0 ...  0 T0]

          `lookback` is strictly positive, unbounded.

        + offset shifts the first such window forward. `0 <= offset < L`;
          think of it as the number of additional non-padded items we
          slide into the window from the right. At its largest value of `L-1`,
          we'll have L-sized windows with `T0` as the leftmost element.

              offset=2
              [0 ... T0 T1 T2]
              |---lookback---|

          In effect, the index of the rightmost item of the first window
          will be equal to the value of `offset`. There are `T - offset`
          total possible windows over a given sequence (regardless of
          lookback).

        + lookahead determines the offset of the "label" slices from the
          first index, regardless of any value of `offset`.

              lookahead=3
              [0 ... 0 T0] T1 T2 [T3]
              |-lookback-|

              0 [0 .. T0 T1] T2 T3 [T4]
                |-lookback-|

          There are `T - lookahead` allowed slices, assuming the lookahead
          exceeds the offset.

        To get windows starting with the first index at the left: we first
        set our window size (call it L), determined by `lookback`. Then the
        rightmost index we want will be `L-1`, which determines our `offset`
        setting.

            lookahead=L, offset=L-1
            [ T_0 ... T_{L-1} ]

        To get a one-step lookahead in front of that rightmost item, the
        `lookahead` can be set to the index of the first label we want:

            lookback=L, offset=L-1 lookahead=L
            [ T_0 ... T_{L-1} ] [ T_L ]
        """

        if self.pre_transform is not None:
            batch_data = self.pre_transform(batch_data)

        lb = self.lookback
        # clamp the offset so the first window always covers index 0
        off = min(self.offset, lb-1)
        la = self.lookahead

        # windowed input tensors, one entry per leading batch tensor
        ws = []
        for t in batch_data[:self.num_windows]:
            # for window sized `lb`, we pad with `lb-1` zeros. We then take off
            # the amount of our offset, which in the extreme cases does no
            # padding.
            xip = F.pad(t, ((lb-1) - off, 0))

            # extract sliding windows over the padded tensor
            # unfold(-1, lb, 1) slides over the last dim, 1 step at a time, for
            # `lb`-sized windows. We turn (C_i, pad+T) shape into
            # (C_i, T-offset, lb), giving `T-offset` total `lb` windows.
            wi = xip.unfold(-1, lb, 1)

            # (C_i, T-offset, lb) -> (T-offset, C_i, lb)
            wi = wi.permute(1, 0, 2)

            # if lookahead exceeds offset, there are some windows for which we
            # won't be able to assign a lookahead label. Cut those off here
            if la - off > 0:
                wi = wi[:-(la-off)]

            ws.append(wi)

        # label tensors, one entry per trailing batch tensor
        ys = []
        for t in batch_data[self.num_windows:]:
            # tensors (C_i, T) shaped, align with lookahead, giving a
            # (C_i, T-lookahead)
            y = t[:, la:]

            # (C_i, T-lookahead) -> (T-lookahead, C_i)
            y = y.permute(1, 0)

            # cut off any elements if offset exceeds lookahead
            if off - la > 0:
                y = y[:-(off-la)]

            ys.append(y)

        # one item tuple per window position: (*windows, *labels)
        return list(zip(*ws, *ys, strict=True))
|
||||||
|
|
||||||
85
trainlib/domain.py
Normal file
85
trainlib/domain.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
"""
|
||||||
|
Defines a knowledge domain. Wraps a Dataset / Simulator / Knowledge
|
||||||
|
|
||||||
|
Downstream exploration might include
|
||||||
|
|
||||||
|
- Calibrating Simulator / Knowledge with a Dataset
|
||||||
|
- Amending Dataset with Simulator / Knowledge
|
||||||
|
- Positioning Knowledge within Simulator context
|
||||||
|
* Where to replace Simulator subsystem with Knowledge?
|
||||||
|
|
||||||
|
Other variations:
|
||||||
|
|
||||||
|
- Multi-fidelity simulators
|
||||||
|
- Multi-scale models
|
||||||
|
- Multi-system
|
||||||
|
- Incomplete knowledge / divergence among sources
|
||||||
|
|
||||||
|
Questions:
|
||||||
|
|
||||||
|
- Should Simulator / Knowledge be unified as one (e.g., "Expert")
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Mapping, Iterator, Sequence
|
||||||
|
|
||||||
|
|
||||||
|
class Domain[U, R](Mapping[U, R]):
    """
    Domain base class, generic to a URI type ``U`` and resource type ``R``.

    Domains are just Mappings where the iterator behavior is specifically typed
    to range over keys (URIs). Defining a specific class here gives us a base
    for a nominal hierarchy, but functionally it is just the Mapping core
    (sized iterables with accessors). Subclasses must override
    ``__getitem__``, ``__iter__``, and ``__len__``.
    """

    def __call__(self, uri: U) -> R:
        """
        Get the resource for a given URI (call-based alias for indexing).
        """

        return self[uri]

    def __getitem__(self, uri: U) -> R:
        """
        Get the resource for a given URI.
        """

        raise NotImplementedError

    def __iter__(self) -> Iterator[U]:
        """
        Provide an iterator over domain URIs.
        """

        raise NotImplementedError

    def __len__(self) -> int:
        """
        Measure the size of the domain (number of URIs).
        """

        raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class SequenceDomain[R](Domain[int, R]):
|
||||||
|
"""
|
||||||
|
Trivial domain implementation for wrapping sequences that can be seen as
|
||||||
|
Mappings with 0-indexed keys.
|
||||||
|
|
||||||
|
Why define this? Domains provide iterators over their *keys*, sequences
|
||||||
|
often iterate over *values*.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, sequence: Sequence[R]) -> None:
|
||||||
|
self.sequence = sequence
|
||||||
|
|
||||||
|
def __getitem__(self, uri: int) -> R:
|
||||||
|
return self.sequence[uri]
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[int]:
|
||||||
|
return iter(range(len(self.sequence)))
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.sequence)
|
||||||
0
trainlib/domains/__init__.py
Normal file
0
trainlib/domains/__init__.py
Normal file
37
trainlib/domains/disk.py
Normal file
37
trainlib/domains/disk.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from collections.abc import Iterator
|
||||||
|
|
||||||
|
from mema.domain import Domain
|
||||||
|
|
||||||
|
|
||||||
|
class DiskDomain(Domain[Path, bytes]):
    """
    Domain over the files directly inside a root directory.

    URIs are ``Path`` objects; resources are the raw file bytes.
    """

    def __init__(
        self,
        root: Path,
        extensions: list[str] | None = None,
    ) -> None:
        """
        Parameters:
            root: directory whose (non-recursive) files make up the domain
            extensions: list of file extensions to filter for when determining
                data file paths to read. Entries may be given with or without
                a leading dot (both ``"csv"`` and ``".csv"`` work). This is a
                whitelist, except when left as ``None`` (which is default), in
                which case all extensions are allowed.
        """

        self.root = root

        # Normalize entries to suffix form (leading dot): `Path.suffix`
        # includes the dot, so un-normalized entries like "csv" would
        # otherwise silently match nothing.
        self.extensions = (
            None
            if extensions is None
            else [e if e.startswith(".") else f".{e}" for e in extensions]
        )

    def __getitem__(self, uri: Path) -> bytes:
        return uri.read_bytes()

    def __iter__(self) -> Iterator[Path]:
        return (
            path for path in self.root.iterdir()
            if path.is_file() and (
                self.extensions is None
                or path.suffix in self.extensions
            )
        )

    def __len__(self) -> int:
        # no index is maintained; count by scanning the root directory
        return sum(1 for _ in iter(self))
|
||||||
58
trainlib/domains/functional.py
Normal file
58
trainlib/domains/functional.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
from collections.abc import Callable, Iterator, Sequence
|
||||||
|
|
||||||
|
from trainlib.domain import Domain
|
||||||
|
|
||||||
|
|
||||||
|
class SimulatorDomain[P, R](Domain[int, R]):
|
||||||
|
"""
|
||||||
|
Base simulator domain, generic to arbitrary callables.
|
||||||
|
|
||||||
|
Note: we don't store simulation results here; that's left to a downstream
|
||||||
|
object, like a `BatchedDataset`, to cache if needed. We also don't subclass
|
||||||
|
`SequenceDataset` because the item getter type doesn't align: we accept an
|
||||||
|
`int` in the parameter list, but don't return the items directly from that
|
||||||
|
collection (we transform them first).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
simulator: Callable[[P], R],
|
||||||
|
parameters: Sequence[P],
|
||||||
|
) -> None:
|
||||||
|
self.simulator = simulator
|
||||||
|
self.parameters = parameters
|
||||||
|
|
||||||
|
def __getitem__(self, uri: int) -> R:
|
||||||
|
return self.simulator(self.parameters[uri])
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[int]:
|
||||||
|
return iter(range(len(self.parameters)))
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.parameters)
|
||||||
|
|
||||||
|
|
||||||
|
class SimulatorPredictiveDomain[P, R](Domain[int, tuple[R, ...]]):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
simulator: Callable[[P], R],
|
||||||
|
parameters: Sequence[P],
|
||||||
|
predictives: Sequence[Callable[[R], R]],
|
||||||
|
) -> None:
|
||||||
|
self.simulator = simulator
|
||||||
|
self.parameters = parameters
|
||||||
|
self.predictives = predictives
|
||||||
|
|
||||||
|
def __getitem__(self, uri: int) -> tuple[R, ...]:
|
||||||
|
sample = self.simulator(self.parameters[uri])
|
||||||
|
|
||||||
|
return (
|
||||||
|
sample,
|
||||||
|
*(p(sample) for p in self.predictives)
|
||||||
|
)
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[int]:
|
||||||
|
return iter(range(len(self.parameters)))
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.parameters)
|
||||||
195
trainlib/estimator.py
Normal file
195
trainlib/estimator.py
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
"""
|
||||||
|
Development note
|
||||||
|
|
||||||
|
I'd rather lay out bare args and kwargs in the estimator methods, but the
|
||||||
|
allowed variance in subclasses makes this difficult. To ease the pain of typing
|
||||||
|
coordination with datasets, `loss()` / `forward()` / `metrics()` specify their
|
||||||
|
arguments through a TypedDict. This makes those signatures more portable, and
|
||||||
|
can bound type variables in other places.
|
||||||
|
|
||||||
|
In theory, even the single tensor input base currently in place could be
|
||||||
|
relaxed given the allowed variability in dataset / dataloader outputs. The
|
||||||
|
default collate function in the PyTorch `DataLoader` source leaves types like
|
||||||
|
strings and bytes unchanged, so not all kinds of data are batched into a
|
||||||
|
`Sequence`, let alone a tensor. Nevertheless, estimators are `nn.Module`
|
||||||
|
derivatives, so it's at minimum a safe assumption we'll need at least one
|
||||||
|
tensor (or *should* have one).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Unpack, TypedDict
|
||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
from torch import nn, Tensor
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
from trainlib.util.type import OptimizerKwargs
|
||||||
|
|
||||||
|
logger: logging.Logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EstimatorKwargs(TypedDict):
    """
    Base keyword-argument schema shared by ``Estimator`` methods
    (``forward`` / ``loss`` / ``metrics``).
    """

    # batch of input tensors passed through the estimator methods
    inputs: Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class Estimator[Kw: EstimatorKwargs](nn.Module):
    """
    Estimator base class.

    All methods that raise ``NotImplementedError`` are directly invoked in the
    ``Trainer.train(...)`` loop.

    Note the flexibility afforded to the signatures of `forward()`, `loss()`,
    and `metrics()` methods, which should generally have identical sets of
    arguments in inheriting classes. The base class is generic to a type `Kw`,
    which should be a `TypedDict` for inheriting classes (despite not being
    enforceable as an upper bound) that reflects the keyword argument
    structure for these methods.

    For instance, in a sequence prediction model with labels and masking, you
    might have something like:

    .. code-block:: python

        class PredictorKwargs(TypedDict, total=False):
            labels: list[Tensor]
            lengths: Tensor
            mask: Tensor

        class SequencePredictor(Estimator[PredictorKwargs]):
            def forward(
                self,
                input: Tensor,
                **kwargs: Unpack[PredictorKwargs],
            ) -> tuple[Tensor, ...]:
                ...

            def loss(
                self,
                input: Tensor,
                **kwargs: Unpack[PredictorKwargs],
            ) -> Generator[Tensor]:
                ...

            def metrics(
                self,
                input: Tensor,
                **kwargs: Unpack[PredictorKwargs],
            ) -> dict[str, float]:
                ...

    While `loss` and `metrics` should leverage the full set of keyword
    arguments, `forward` may not (e.g., in the example above, it shouldn't use
    `labels`).

    Subclasses of `SequencePredictor` should then be generic over subtypes of
    `PredictorKwargs`.
    """

    def forward(
        self,
        **kwargs: Unpack[Kw],
    ) -> tuple[Tensor, ...]:
        """
        Compute model outputs for the given input batch.
        """

        raise NotImplementedError

    def loss(
        self,
        **kwargs: Unpack[Kw],
    ) -> Generator[Tensor]:
        """
        Compute model loss for the given input.

        Note that the loss is implemented as a generator to support
        multi-objective estimator setups. That is, losses can be yielded in
        sequence, allowing for a training loop to propagate model parameter
        updates before the next loss function is calculated. For instance, in a
        GAN-like setup, one might first emit the D-loss, update D parameters,
        then compute the G-loss (*depending* on the updated D parameters). Such
        a scheme is not otherwise possible without bringing the intermediate
        parameter update "in house" (breaking a separation of duties with the
        train loop).
        """

        raise NotImplementedError

    def metrics(
        self,
        **kwargs: Unpack[Kw],
    ) -> dict[str, float]:
        """
        Compute metrics for the given input.
        """

        raise NotImplementedError

    def optimizers(
        self,
        **kwargs: Unpack[OptimizerKwargs],
    ) -> tuple[Optimizer, ...]:
        """
        Get optimizers for the estimator to use in training loops.

        Example providing a singular Adam-based optimizer:

        .. code-block:: python

            def optimizers(...):
                optimizer = torch.optim.AdamW(
                    self.parameters(),
                    lr=1e-3,
                    eps=1e-8,
                )
                return (optimizer,)
        """

        raise NotImplementedError

    def epoch_step(self) -> None:
        """
        Step epoch-dependent model state.

        This method should not include optimization of primary model
        parameters; that should be left to external optimizers. Instead, this
        method should step forward things like internal hyperparameter
        schedules in accordance with the expected call rate (e.g., every
        epoch).
        """

        raise NotImplementedError

    def epoch_write(
        self,
        writer: SummaryWriter,
        step: int | None = None,
        val: bool = False,
        **kwargs: Unpack[Kw],
    ) -> None:
        """
        Write epoch-dependent Tensorboard items.

        If implemented, this should supplement that which is provided in
        ``metrics()``. Tensors provided via ``kwargs`` should include raw
        training/validation data; examples include writing raw embeddings,
        canonical visualizations of the samples (e.g., previewing images), etc.

        Parameters:
            writer: tensorboard writer instance
            step: current step in the optimization loop
            val: whether input is a validation sample
            kwargs: batch of tensors from the current epoch
        """

        raise NotImplementedError

    def log_arch(self) -> None:
        """
        Log Estimator architecture details (class name and trainable
        parameter count).
        """

        logger.info(f"> Estimator :: {self.__class__.__name__}")

        # count only trainable (requires_grad) parameters
        num_params = sum(
            p.numel() for p in self.parameters() if p.requires_grad
        )
        logger.info(f"| > # model parameters: {num_params}")
|
||||||
0
trainlib/estimators/__init__.py
Normal file
0
trainlib/estimators/__init__.py
Normal file
491
trainlib/estimators/rnn.py
Normal file
491
trainlib/estimators/rnn.py
Normal file
@@ -0,0 +1,491 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Unpack, NotRequired
|
||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn, Tensor
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
from mema.estimator import Estimator, EstimatorKwargs
|
||||||
|
from mema.util.type import OptimizerKwargs
|
||||||
|
from mema.util.module import get_grad_norm
|
||||||
|
from mema.estimators.tdnn import TDNNLayer
|
||||||
|
|
||||||
|
logger: logging.Logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RNNKwargs(EstimatorKwargs):
    """
    Keyword-argument schema for the RNN estimators in this module.
    """

    # input batch; consumed by forward() as (B, C, T) and permuted to
    # (B, T, C) before the recurrent stack
    inputs: Tensor
    # regression targets; required by loss()/metrics(), unused by forward()
    labels: NotRequired[Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
class LSTM[K: RNNKwargs](Estimator[K]):
    """
    Base RNN architecture: input projection -> stacked LSTM -> linear head,
    predicting a vector from the final time step of a sequence.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: int = 64,
        num_layers: int = 4,
        bidirectional: bool = False,
        verbose: bool = True,
    ) -> None:
        """
        Parameters:
            input_dim: dimensionality of the input
            output_dim: dimensionality of the output
            hidden_dim: dimensionality of the LSTM hidden state (and of the
                input projection)
            num_layers: number of stacked LSTM layers
            bidirectional: whether to run the LSTM bidirectionally
            verbose: whether to log architecture details at construction
        """

        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        self.dense_in = nn.Linear(input_dim, hidden_dim)
        self.lstm = nn.LSTM(
            hidden_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
        )

        # bidirectional LSTMs concatenate forward/backward hidden states
        lstm_out_dim = hidden_dim * (2 if bidirectional else 1)
        self.dense_z = nn.Linear(lstm_out_dim, output_dim)

        # weight initialization for LSTM layers
        def init_weights(m: nn.Module) -> None:
            if isinstance(m, nn.LSTM):
                for name, p in m.named_parameters():
                    if "weight_ih" in name:
                        nn.init.xavier_uniform_(p)
                    elif "weight_hh" in name:
                        nn.init.orthogonal_(p)
                    elif "bias" in name:
                        nn.init.zeros_(p)

        self.apply(init_weights)

        if verbose:
            self.log_arch()

    def _clamp_rand(self, x: Tensor) -> Tensor:
        # dither with small uniform noise (±1/254) and clamp to [-1, 1];
        # presumably a regularizer for the tanh-activated embedding — TODO
        # confirm intent
        return torch.clamp(
            x + (1.0 / 127.0) * (torch.rand_like(x) - 0.5),
            min=-1.0,
            max=1.0,
        )

    def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]:
        """
        Return the last-step prediction and the LSTM's final hidden state.
        """

        inputs = kwargs["inputs"]

        # data shaped (B, C, T); map to (B, T, C)
        x = inputs.permute(0, 2, 1)
        x = torch.tanh(self.dense_in(x))
        x = self._clamp_rand(x)
        x, hidden = self.lstm(x)
        z = self.dense_z(x)

        # keep only the final time step's projection
        return z[:, -1, :], hidden

    def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]:
        """
        Yield the MSE between last-step predictions and ``labels``.
        """

        predictions = self(**kwargs)[0]
        labels = kwargs["labels"]

        yield F.mse_loss(predictions, labels)

    def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]:
        """
        Compute loss and gradient-norm metrics for the given input.
        """

        with torch.no_grad():
            loss = next(self.loss(**kwargs)).item()

        return {
            "loss": loss,
            "grad_norm": get_grad_norm(self)
        }

    def optimizers(
        self,
        **kwargs: Unpack[OptimizerKwargs],
    ) -> tuple[Optimizer, ...]:
        """
        Provide a single AdamW optimizer over all parameters; ``kwargs``
        override the defaults (``lr=1e-3``, ``eps=1e-8``).
        """

        default_kwargs: OptimizerKwargs = {
            "lr": 1e-3,
            "eps": 1e-8,
        }
        opt_kwargs = {**default_kwargs, **kwargs}

        optimizer = torch.optim.AdamW(
            self.parameters(),
            **opt_kwargs,
        )

        return (optimizer,)

    def epoch_step(self) -> None:
        # no epoch-dependent internal state to advance
        return None

    def epoch_write(
        self,
        writer: SummaryWriter,
        step: int | None = None,
        val: bool = False,
        **kwargs: Unpack[K],
    ) -> None:
        # no supplemental tensorboard output beyond metrics()
        return None

    def log_arch(self) -> None:
        """
        Log architecture hyperparameters.
        """

        super().log_arch()

        logger.info(f"| > {self.input_dim=}")
        logger.info(f"| > {self.hidden_dim=}")
        logger.info(f"| > {self.num_layers=}")
        logger.info(f"| > {self.output_dim=}")
|
||||||
|
|
||||||
|
|
||||||
|
class MultiheadLSTMKwargs(EstimatorKwargs):
    """
    Keyword-argument schema for ``MultiheadLSTM``.
    """

    # input batch; consumed by forward() as (B, C, T)
    inputs: Tensor
    # primary regression targets; required by loss()/metrics()
    labels: NotRequired[Tensor]
    # targets for the concatenated auxiliary heads; optional in loss()
    auxiliary: NotRequired[Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_dim: int,
|
||||||
|
output_dim: int,
|
||||||
|
hidden_dim: int = 64,
|
||||||
|
num_layers: int = 4,
|
||||||
|
bidirectional: bool = False,
|
||||||
|
head_dims: list[int] | None = None,
|
||||||
|
verbose: bool = True,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.output_dim = output_dim
|
||||||
|
self.hidden_dim = hidden_dim
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.head_dims = head_dims if head_dims is not None else []
|
||||||
|
|
||||||
|
self.dense_in = nn.Linear(input_dim, hidden_dim)
|
||||||
|
self.lstm = nn.LSTM(
|
||||||
|
hidden_dim,
|
||||||
|
hidden_dim,
|
||||||
|
num_layers=num_layers,
|
||||||
|
batch_first=True,
|
||||||
|
bidirectional=bidirectional,
|
||||||
|
)
|
||||||
|
|
||||||
|
lstm_out_dim = hidden_dim * (2 if bidirectional else 1)
|
||||||
|
self.dense_z_out = nn.Linear(lstm_out_dim, output_dim)
|
||||||
|
self.dense_z_heads = nn.ModuleList([
|
||||||
|
nn.Linear(lstm_out_dim, head_dim)
|
||||||
|
for head_dim in self.head_dims
|
||||||
|
])
|
||||||
|
|
||||||
|
# weight initialization for LSTM layers
|
||||||
|
def init_weights(m: nn.Module) -> None:
|
||||||
|
if isinstance(m, nn.LSTM):
|
||||||
|
for name, p in m.named_parameters():
|
||||||
|
if "weight_ih" in name:
|
||||||
|
nn.init.xavier_uniform_(p)
|
||||||
|
elif "weight_hh" in name:
|
||||||
|
nn.init.orthogonal_(p)
|
||||||
|
elif "bias" in name:
|
||||||
|
nn.init.zeros_(p)
|
||||||
|
|
||||||
|
self.apply(init_weights)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
self.log_arch()
|
||||||
|
|
||||||
|
def _clamp_rand(self, x: Tensor) -> Tensor:
|
||||||
|
return torch.clamp(
|
||||||
|
x + (1.0 / 127.0) * (torch.rand_like(x) - 0.5),
|
||||||
|
min=-1.0,
|
||||||
|
max=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]:
|
||||||
|
inputs = kwargs["inputs"]
|
||||||
|
|
||||||
|
# data shaped (B, C, T); map to (B, T, C)
|
||||||
|
x = inputs.permute(0, 2, 1)
|
||||||
|
x = torch.tanh(self.dense_in(x))
|
||||||
|
x = self._clamp_rand(x)
|
||||||
|
x, hidden = self.lstm(x)
|
||||||
|
|
||||||
|
z = self.dense_z_out(x)
|
||||||
|
zs = torch.cat([head(x) for head in self.dense_z_heads], dim=-1)
|
||||||
|
|
||||||
|
return z[:, -1, :], zs[:, -1, :]
|
||||||
|
|
||||||
|
def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]:
|
||||||
|
pred, pred_aux = self(**kwargs)
|
||||||
|
labels = kwargs["labels"]
|
||||||
|
aux_labels = kwargs.get("auxiliary")
|
||||||
|
|
||||||
|
if aux_labels is None:
|
||||||
|
yield F.mse_loss(pred, labels)
|
||||||
|
else:
|
||||||
|
yield F.mse_loss(pred, labels) + F.mse_loss(pred_aux, aux_labels)
|
||||||
|
|
||||||
|
def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]:
|
||||||
|
with torch.no_grad():
|
||||||
|
loss = next(self.loss(**kwargs)).item()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"loss": loss,
|
||||||
|
"grad_norm": get_grad_norm(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
def optimizers(
|
||||||
|
self,
|
||||||
|
**kwargs: Unpack[OptimizerKwargs],
|
||||||
|
) -> tuple[Optimizer, ...]:
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
|
||||||
|
default_kwargs: Unpack[OptimizerKwargs] = {
|
||||||
|
"lr": 1e-3,
|
||||||
|
"eps": 1e-8,
|
||||||
|
}
|
||||||
|
opt_kwargs = {**default_kwargs, **kwargs}
|
||||||
|
|
||||||
|
optimizer = torch.optim.AdamW(
|
||||||
|
self.parameters(),
|
||||||
|
**opt_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (optimizer,)
|
||||||
|
|
||||||
|
def epoch_step(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def epoch_write(
|
||||||
|
self,
|
||||||
|
writer: SummaryWriter,
|
||||||
|
step: int | None = None,
|
||||||
|
val: bool = False,
|
||||||
|
**kwargs: Unpack[K],
|
||||||
|
) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def log_arch(self) -> None:
|
||||||
|
super().log_arch()
|
||||||
|
|
||||||
|
logger.info(f"| > {self.input_dim=}")
|
||||||
|
logger.info(f"| > {self.hidden_dim=}")
|
||||||
|
logger.info(f"| > {self.num_layers=}")
|
||||||
|
logger.info(f"| > {self.output_dim=}")
|
||||||
|
|
||||||
|
|
||||||
|
class ConvRNN[K: RNNKwargs](Estimator[K]):
    """
    Base recurrent convolutional architecture.

    Computes latents, initial states, and rate estimates from features and
    lambda parameter. Structurally: an input TDNN embedding followed by
    ``num_layers`` GRU/TDNN pairs whose outputs are densely concatenated
    (each layer sees all previous layer outputs), then a final TDNN
    projection.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        temporal_dim: int,
        gru_dim: int = 64,
        conv_dim: int = 96,
        num_layers: int = 4,
        conv_kernel_sizes: list[int] | None = None,
        conv_dilations: list[int] | None = None,
        verbose: bool = True,
    ) -> None:
        """
        Parameters:
            input_dim: dimensionality of the input
            output_dim: dimensionality of the output
            temporal_dim: kernel size of the final (unpadded) TDNN
                projection; loss() assumes this collapses the time axis to
                length 1 — presumably it should match the post-stack
                sequence length (TODO confirm with callers)
            gru_dim: dimensionality of each GRU layer output
            conv_dim: dimensionality of each conv layer output
            num_layers: number of gru-conv layer pairs to use
            conv_kernel_sizes: kernel sizes for conv layers (default: 4 each)
            conv_dilations: dilation settings for conv layers (default: 1 for
                the first layer, 2 for the rest)
            verbose: whether to log architecture details at construction
        """

        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        self.gru_dim = gru_dim
        self.conv_dim = conv_dim
        self.num_layers = num_layers
        # accumulated receptive field across the TDNN layers (summed below)
        self.receptive_field = 0

        self.conv_kernel_sizes: list[int]
        if conv_kernel_sizes is None:
            self.conv_kernel_sizes = [4] * num_layers
        else:
            self.conv_kernel_sizes = conv_kernel_sizes

        self.conv_dilations: list[int]
        if conv_dilations is None:
            self.conv_dilations = [1] + [2] * (num_layers - 1)
        else:
            self.conv_dilations = conv_dilations

        self._gru_layers: nn.ModuleList = nn.ModuleList()
        self._conv_layers: nn.ModuleList = nn.ModuleList()

        # dense connectivity: each layer's input width grows by the output
        # width of every preceding GRU/conv layer
        layer_in_dim = gru_dim
        for i in range(self.num_layers):
            gru_layer = nn.GRU(layer_in_dim, gru_dim, batch_first=True)
            self._gru_layers.append(gru_layer)
            layer_in_dim += gru_dim

            tdnn_layer = TDNNLayer(
                layer_in_dim,
                conv_dim,
                kernel_size=self.conv_kernel_sizes[i],
                dilation=self.conv_dilations[i],
            )
            self.receptive_field += tdnn_layer.receptive_field

            self._conv_layers.append(tdnn_layer)
            layer_in_dim += conv_dim

        # pointwise (kernel_size=1) embedding of raw input channels
        self.dense_in = TDNNLayer(
            self.input_dim,
            gru_dim,
            kernel_size=1,
            pad=False
        )

        # will be (B, C, T), applies indep at each time step across channels;
        # unpadded, so a kernel spanning temporal_dim steps shortens T
        self.dense_z = TDNNLayer(
            layer_in_dim,
            self.output_dim,
            kernel_size=temporal_dim,
            pad=False,
        )

        # weight initialization for GRU layers
        def init_weights(module: nn.Module) -> None:
            if isinstance(module, nn.GRU):
                for p in module.named_parameters():
                    if p[0].startswith("weight_hh_"):
                        nn.init.orthogonal_(p[1])

        self.apply(init_weights)

        if verbose:
            self.log_arch()

    def _clamp_rand(self, x: Tensor) -> Tensor:
        # dither with small uniform noise (±1/254) and clamp to [-1, 1]
        return torch.clamp(
            x + (1.0 / 127.0) * (torch.rand_like(x) - 0.5),
            min=-1.0,
            max=1.0,
        )

    def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]:
        """
        Return a 1-tuple with the (B, C, T')-shaped projection of the
        densely-concatenated GRU/conv stack.
        """

        inputs = kwargs["inputs"]

        # embedding shaped (B, C, T)
        x = self._clamp_rand(torch.tanh(self.dense_in(inputs)))

        # prepare shape (B, T, C) -- for GRU
        x = x.transpose(-2, -1)

        for gru, conv in zip(self._gru_layers, self._conv_layers, strict=True):
            # GRU output concatenated onto the running feature stack
            xg = self._clamp_rand(gru(x)[0])
            x = torch.cat([x, xg], -1)

            # conv operates channel-first; transpose in and back out
            xc = self._clamp_rand(conv(x.transpose(-2, -1)))
            xc = xc.transpose(-2, -1)
            x = torch.cat([x, xc], -1)

        x = x.transpose(-2, -1)
        # map to (B, C, T)
        z = self.dense_z(x)

        return (z,)

    def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]:
        """
        Yield the MSE between the (time-collapsed) prediction and labels.
        """

        predictions = self(**kwargs)[0]
        labels = kwargs["labels"]

        # squeeze last dim; we've mapped T -> 1
        predictions = predictions.squeeze(-1)

        yield F.mse_loss(predictions, labels, reduction="mean")

    def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]:
        """
        Compute loss and gradient-norm metrics for the given input.
        """

        with torch.no_grad():
            loss = next(self.loss(**kwargs)).item()

        return {
            "loss": loss,
            "grad_norm": get_grad_norm(self)
        }

    def optimizers(
        self,
        **kwargs: Unpack[OptimizerKwargs],
    ) -> tuple[Optimizer, ...]:
        """
        Provide a single AdamW optimizer over all parameters; ``kwargs``
        override the defaults (``lr=1e-3``, ``eps=1e-8``).
        """

        default_kwargs: OptimizerKwargs = {
            "lr": 1e-3,
            "eps": 1e-8,
        }
        opt_kwargs = {**default_kwargs, **kwargs}

        optimizer = torch.optim.AdamW(
            self.parameters(),
            **opt_kwargs,
        )

        return (optimizer,)

    def epoch_step(self) -> None:
        # no epoch-dependent internal state to advance
        return None

    def epoch_write(
        self,
        writer: SummaryWriter,
        step: int | None = None,
        val: bool = False,
        **kwargs: Unpack[K],
    ) -> None:
        # no supplemental tensorboard output beyond metrics()
        return None

    def log_arch(self) -> None:
        """
        Log architecture hyperparameters.
        """

        super().log_arch()

        logger.info(f"| > {self.input_dim=}")
        logger.info(f"| > {self.gru_dim=}")
        logger.info(f"| > {self.conv_dim=}")
        logger.info(f"| > {self.num_layers=}")
        logger.info(f"| > {self.conv_kernel_sizes=}")
        logger.info(f"| > {self.conv_dilations=}")
        logger.info(f"| > {self.receptive_field=}")
        logger.info(f"| > {self.output_dim=}")
|
||||||
114
trainlib/estimators/tdnn.py
Normal file
114
trainlib/estimators/tdnn.py
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
import logging
|
||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn, Tensor
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from torch.nn.utils.parametrizations import weight_norm
|
||||||
|
|
||||||
|
logger: logging.Logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TDNNLayer(nn.Module):
|
||||||
|
"""
|
||||||
|
Time delay neural network layer.
|
||||||
|
|
||||||
|
Built on torch Conv1D layers, with additional support for automatic
|
||||||
|
padding.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_channels: int,
|
||||||
|
output_channels: int,
|
||||||
|
kernel_size: int = 3,
|
||||||
|
dilation: int = 1,
|
||||||
|
lookahead: int = 0,
|
||||||
|
pad: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Implements a fast TDNN layer via `torch.Conv1d`.
|
||||||
|
|
||||||
|
Note that we're restricted to the kernel shapes producible by `Conv1d`
|
||||||
|
objects, implying (primarily) that the kernel must be symmetric and
|
||||||
|
have equal spacing.
|
||||||
|
|
||||||
|
For example, the symmetric but non-uniform context [-3, -2, 0, +2, +3]
|
||||||
|
cannot be represented, while [-6, -3, 0, 3, 6] can. A few other kernel
|
||||||
|
examples:
|
||||||
|
|
||||||
|
kernel_size=3; dilation=1 -> [-1, 0, 1]
|
||||||
|
kernel_size=3; dilation=3 -> [-3, 0, 3]
|
||||||
|
kernel_size=4; dilation=2 -> [-3, -1, 1, 3]
|
||||||
|
|
||||||
|
By default, the TDNN layer left pads the input to ensure the output has
|
||||||
|
the same sequence length as the original input. For example, with a
|
||||||
|
kernel size of 3 (dilation of 1) and sequence length of T=3, the
|
||||||
|
sequence will be left padded with 2 zeros:
|
||||||
|
|
||||||
|
[0, 0, 1, 1, 1] -> [x1, x2, x3]
|
||||||
|
|
||||||
|
If a lookahead is specified, some number of those left zeros will be
|
||||||
|
moved to the right. If lookahead=1, for instance, indicating 1
|
||||||
|
additional "future frame" of context, padding will look like
|
||||||
|
|
||||||
|
[0, 1, 1, 1, 0] -> [x1, x2, x3]
|
||||||
|
|
||||||
|
The output x_i now sees through time step i+1.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
input_channels: number of input channels
|
||||||
|
output_channels: number of channels produced by the temporal
|
||||||
|
convolution
|
||||||
|
kernel_size: total size of the kernel
|
||||||
|
dilation: dilation of receptive field, i.e., the size of the gaps
|
||||||
|
lookahead: number of allowed lookahead frames
|
||||||
|
pad: whether the input should be padded, producing an output with
|
||||||
|
the same sequence length as the input
|
||||||
|
"""
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.td_conv: nn.Module = weight_norm(
|
||||||
|
nn.Conv1d(
|
||||||
|
input_channels,
|
||||||
|
output_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
dilation=dilation,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.pad: bool = pad
|
||||||
|
self.lookahead = lookahead
|
||||||
|
self.receptive_field: int = (kernel_size - 1) * dilation + 1
|
||||||
|
|
||||||
|
assert (
|
||||||
|
self.lookahead < self.receptive_field
|
||||||
|
), "Lookahead cannot exceed receptive field"
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
Dimension definitions:
|
||||||
|
|
||||||
|
- B: batch size
|
||||||
|
- Di: input dimension at single time step (aka input channels)
|
||||||
|
- Do: output dimension at single time step (aka output channels)
|
||||||
|
- T: sequence length
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
x: input tensor, shaped [B, Di, T] (optionally w/o a batch dim)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tensor shaped [B, Do, T-kernel_size]
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pad according to receptive field and lookahead s.t. output
|
||||||
|
# shape [B, *, T]
|
||||||
|
if self.pad:
|
||||||
|
x = F.pad(
|
||||||
|
x,
|
||||||
|
(self.receptive_field - self.lookahead - 1, self.lookahead)
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.td_conv(x)
|
||||||
509
trainlib/trainer.py
Normal file
509
trainlib/trainer.py
Normal file
@@ -0,0 +1,509 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from io import BytesIO
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import Any, Self
|
||||||
|
from pathlib import Path
|
||||||
|
from collections import defaultdict
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
from torch import cuda, Tensor
|
||||||
|
from torch.nn.utils import clip_grad_norm_
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
from trainlib.dataset import BatchedDataset
|
||||||
|
from trainlib.estimator import Estimator, EstimatorKwargs
|
||||||
|
from trainlib.transform import Transform
|
||||||
|
from trainlib.util.type import (
|
||||||
|
SplitKwargs,
|
||||||
|
LoaderKwargs,
|
||||||
|
BalanceKwargs,
|
||||||
|
)
|
||||||
|
from trainlib.util.module import ModelWrapper
|
||||||
|
|
||||||
|
logger: logging.Logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Trainer[I, K: EstimatorKwargs]:
|
||||||
|
"""
|
||||||
|
Training interface for updating ``Estimators`` with ``Datasets``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
    self,
    estimator: Estimator[K],
    device: str | None = None,
    chkpt_dir: str = "chkpt/",
    tblog_dir: str = "tblog/",
) -> None:
    """
    Parameters:
        estimator: `Estimator` model object
        device: device on which to carry out training; when None,
            auto-selects "cuda" if available, else "cpu"
        chkpt_dir: directory where model checkpoints are written
        tblog_dir: directory where TensorBoard event files are written
    """

    # resolve the compute device, preferring CUDA when none is given
    self.device: str
    if device is None:
        self.device = "cuda" if cuda.is_available() else "cpu"
    else:
        self.device = device

    logger.info(f"> Trainer device: {self.device}")
    if self.device.startswith("cuda"):
        if torch.cuda.is_available():
            # extra cuda details
            logger.info(f"| > {cuda.device_count()=}")
            logger.info(f"| > {cuda.current_device()=}")
            logger.info(f"| > {cuda.get_device_name()=}")
            logger.info(f"| > {cuda.get_device_capability()=}")

            # memory info (in GB)
            gb = 1024**3
            memory_allocated = cuda.memory_allocated() / gb
            memory_reserved = cuda.memory_reserved() / gb
            # NOTE(review): queries device index 0 regardless of the
            # selected device string (e.g. "cuda:1") -- confirm intended
            memory_total = cuda.get_device_properties(0).total_memory / gb

            logger.info("| > CUDA memory:")
            logger.info(f"| > {memory_total=:.2f}GB")
            logger.info(f"| > {memory_reserved=:.2f}GB")
            logger.info(f"| > {memory_allocated=:.2f}GB")
        else:
            # a CUDA device string was requested but CUDA is unusable
            logger.warning("| > CUDA device specified but not available")
    else:
        logger.info("| > Using CPU device - no additional device info")

    # move the estimator onto the training device
    self.estimator = estimator
    self.estimator.to(self.device)

    # resolve output directories up front so later joins are absolute
    self.chkpt_dir = Path(chkpt_dir).resolve()
    self.tblog_dir = Path(tblog_dir).resolve()

    # initialize step/epoch/summary/early-stopping tracking state
    self.reset()
||||||
|
|
||||||
|
def reset(self) -> None:
    """
    Set base tracking parameters.
    """

    # global sample counter and current epoch (epoch counting restarts
    # from 1 inside `train()`)
    self._step: int = 0
    self._epoch: int = 0
    # per-metric records of (value, global_step) pairs, flushed to
    # TensorBoard by `_summarize()`
    self._summary: dict[str, list[tuple[float, int]]] = defaultdict(list)

    # early-stopping state: latest and best validation losses, epochs
    # without improvement, and a snapshot of the best model weights
    self._val_loss = float("inf")
    self._best_val_loss = float("inf")
    self._stagnant_epochs = 0
    self._best_model_state_dict: dict[str, Any] = {}
||||||
|
|
||||||
|
def train(
    self,
    dataset: BatchedDataset[..., ..., I],
    batch_estimator_map: Callable[[I, Self], K],
    lr: float = 1e-3,
    eps: float = 1e-8,
    max_grad_norm: float | None = None,
    max_epochs: int = 10,
    stop_after_epochs: int = 5,
    batch_size: int = 256,
    val_frac: float = 0.1,
    train_transform: Transform | None = None,
    val_transform: Transform | None = None,
    dataset_split_kwargs: SplitKwargs | None = None,
    dataset_balance_kwargs: BalanceKwargs | None = None,
    dataloader_kwargs: LoaderKwargs | None = None,
    summarize_every: int = 1,
    chkpt_every: int = 1,
    resume_latest: bool = False,
    summary_writer: SummaryWriter | None = None,
) -> Estimator:
    """
    Note: this method attempts to implement a general scheme for passing
    needed items to the estimator's loss function from the dataloader. The
    abstract `Estimator` base only requires the model output be provided
    for any given loss calculation, but concrete estimators will often
    require additional arguments (e.g., labels or length masks, as
    is the case with sequential models). In any case, this method defers
    any further logic to the `loss` method of the underlying estimator, so
    one should take care to synchronize the sample structure with `dataset`
    to match that expected by `self.estimator.loss(...)`.

    On batch_estimator_map:

    Dataloader collate functions are responsible for mapping a collection
    of items into an item of collections, roughly speaking. If items are
    tuples of tensors,

        [
            ( [1, 1], [1, 1] ),
            ( [2, 2], [2, 2] ),
            ( [3, 3], [3, 3] ),
        ]

    the collate function maps back into the item skeleton, producing a
    single tuple of (stacked) tensors

        ( [[1, 1],
           [2, 2],
           [3, 3]],

          [[1, 1],
           [2, 2],
           [3, 3]] )

    This function should map from batches (which should be *item shaped*,
    i.e., have an `I` skeleton, even if stacked items may be different on
    the inside) into estimator keyword arguments (type `K`). The produced
    mapping must include an "inputs" entry, which is used to advance the
    global step counter.

    Parameters:
        lr: learning rate (default: 1e-3)
        eps: adam EPS (default: 1e-8)
        max_grad_norm: when set, clip gradients of each optimizer's
            parameters to this norm before stepping
        max_epochs: maximum number of training epochs
        stop_after_epochs: number of epochs with stagnant validation losses
            to allow before early stopping. If training stops earlier, the
            parameters for the best recorded validation score are loaded
            into the estimator before the method returns. If
            `stop_after_epochs >= max_epochs`, the estimator will train
            over all epochs and return as is, irrespective of validation
            scores.
        batch_size: size of batch to use when training on the provided
            dataset
        val_frac: fraction of dataset to use for validation; when <= 0,
            no validation pass is run
        summarize_every: flush summary records every this many epochs
        chkpt_every: how often model checkpoints should be saved
        resume_latest: resume training from the latest available checkpoint
            in the `chkpt_dir`
        summary_writer: optional preconfigured TensorBoard writer; a new
            one rooted at `self.tblog_dir` is created when omitted
    """

    logger.info("> Begin train loop:")
    logger.info(f"| > {lr=}")
    logger.info(f"| > {eps=}")
    logger.info(f"| > {max_epochs=}")
    logger.info(f"| > {batch_size=}")
    logger.info(f"| > {val_frac=}")
    logger.info(f"| > {chkpt_every=}")
    # NOTE(review): `resume_latest` is logged but no resume logic is
    # visible in this method -- confirm whether it is implemented
    logger.info(f"| > {resume_latest=}")
    logger.info(f"| > with device: {self.device}")
    logger.info(f"| > core count: {os.cpu_count()}")

    # NOTE(review): `dir_prefix` namespaces checkpoints by start time but
    # is not applied to the SummaryWriter path -- confirm intended
    writer: SummaryWriter
    dir_prefix = str(int(time.time()))
    if summary_writer is None:
        writer = SummaryWriter(f"{self.tblog_dir}")
    else:
        writer = summary_writer

    train_loader, val_loader = self.get_dataloaders(
        dataset,
        batch_size,
        val_frac=val_frac,
        train_transform=train_transform,
        val_transform=val_transform,
        dataset_split_kwargs=dataset_split_kwargs,
        dataset_balance_kwargs=dataset_balance_kwargs,
        dataloader_kwargs=dataloader_kwargs,
    )

    # one optimizer per loss term; `loss()` yields terms in matching order
    optimizers = self.estimator.optimizers(lr=lr, eps=eps)

    self._step = 0
    self._epoch = 1  # start from 1 for logging convenience
    while self._epoch <= max_epochs and not self._converged(
        self._epoch, stop_after_epochs
    ):
        print(f"Training epoch {self._epoch}/{max_epochs}...")
        print(f"Stagnant epochs {self._stagnant_epochs}/{stop_after_epochs}...")

        epoch_start_time = time.time()
        # running per-optimizer loss totals for the epoch
        train_loss_sums = []
        self.estimator.train()
        with tqdm(train_loader, unit="batch") as train_epoch:
            for i, batch_data in enumerate(train_epoch):
                est_kwargs = batch_estimator_map(batch_data, self)
                inputs = est_kwargs["inputs"]

                # one-time logging
                if self._step == 0:
                    writer.add_graph(ModelWrapper(self.estimator), est_kwargs)

                # once-per-epoch logging
                if i == 0:
                    self.estimator.epoch_write(
                        writer,
                        step=self._step,
                        val=False,
                        **est_kwargs
                    )

                # `loss()` is a generator: consume one term per optimizer
                train_losses = self.estimator.loss(**est_kwargs)
                train_loss_items = []
                for o_idx, optimizer in enumerate(optimizers):
                    optimizer.zero_grad()
                    train_loss = next(train_losses)

                    # lazily grow the per-optimizer running sums
                    if len(train_loss_sums) <= o_idx:
                        train_loss_sums.append(0.0)

                    train_loss_item = train_loss.item()
                    train_loss_sums[o_idx] += train_loss_item
                    train_loss_items.append(train_loss_item)

                    train_loss.backward()

                    # clip gradients for optimizer's parameters
                    if max_grad_norm is not None:
                        opt_params = self._get_optimizer_parameters(optimizer)
                        clip_grad_norm_(opt_params, max_norm=max_grad_norm)

                    optimizer.step()

                # global step counts samples, not batches
                self._step += len(inputs)

                for train_loss_item, train_loss_sum in zip(
                    train_loss_items,
                    train_loss_sums,
                    strict=True,
                ):
                    train_epoch.set_postfix(loss=f"{train_loss_sum/(i+1):8.2f}")
                    self._add_summary_item("train_loss", train_loss_item)

                estimator_metrics = self.estimator.metrics(**est_kwargs)
                for metric_name, metric_value in estimator_metrics.items():
                    self._add_summary_item(f"train_{metric_name}", metric_value)

        self.estimator.epoch_step()

        for li, train_loss_sum in enumerate(train_loss_sums):
            self._add_summary_item(
                f"train_loss{li}_epoch", train_loss_sum / len(train_loader)
            )

        if val_frac > 0:
            val_loss_sums = []
            self.estimator.eval()
            with tqdm(val_loader, unit="batch") as val_epoch:
                for i, batch_data in enumerate(val_epoch):
                    est_kwargs = batch_estimator_map(batch_data, self)
                    inputs = est_kwargs["inputs"]

                    # once-per-epoch logging
                    if i == 0:
                        self.estimator.epoch_write(
                            writer,
                            step=self._step,
                            val=True,
                            **est_kwargs
                        )

                    val_losses = self.estimator.loss(**est_kwargs)
                    val_loss_items = []
                    for o_idx in range(len(optimizers)):
                        val_loss = next(val_losses)

                        if len(val_loss_sums) <= o_idx:
                            val_loss_sums.append(0.0)

                        val_loss_item = val_loss.item()
                        val_loss_sums[o_idx] += val_loss_item
                        val_loss_items.append(val_loss_item)

                    for val_loss_item, val_loss_sum in zip(
                        val_loss_items,
                        val_loss_sums,
                        strict=True,
                    ):
                        val_epoch.set_postfix(loss=f"{val_loss_sum/(i+1):8.2f}")
                        self._add_summary_item("val_loss", val_loss_item)

                    estimator_metrics = self.estimator.metrics(**est_kwargs)
                    for metric_name, metric_value in estimator_metrics.items():
                        self._add_summary_item(f"val_{metric_name}", metric_value)

            for li, val_loss_sum in enumerate(val_loss_sums):
                self._add_summary_item(
                    f"val_loss{li}_epoch", val_loss_sum / len(val_loader)
                )

            # convergence of multiple losses may be ambiguous
            self._val_loss = sum(val_loss_sums) / len(val_loader)

        self._add_summary_item("epoch_time_sec", time.time() - epoch_start_time)

        if self._epoch % summarize_every == 0:
            self._summarize(writer, self._epoch)

        # save checkpoint
        if self._epoch % chkpt_every == 0:
            self.save_model(
                self._epoch, self.chkpt_dir, dir_prefix
            )

        self._epoch += 1

    return self.estimator
||||||
|
|
||||||
|
def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
    """
    Early-stopping check, called once per epoch before training it.

    Tracks the best validation loss seen so far (snapshotting the model
    weights when it improves) and counts stagnant epochs otherwise. Once
    the stagnant count reaches `stop_after_epochs`, the best snapshot is
    restored onto the estimator and True is returned.
    """

    # the first epoch always counts as an improvement, establishing the
    # initial snapshot
    improved = epoch == 1 or self._val_loss < self._best_val_loss

    if improved:
        self._best_val_loss = self._val_loss
        self._stagnant_epochs = 0
        self._best_model_state_dict = deepcopy(self.estimator.state_dict())
    else:
        self._stagnant_epochs += 1

    if self._stagnant_epochs < stop_after_epochs:
        return False

    # stagnated: roll back to the best recorded weights
    self.estimator.load_state_dict(self._best_model_state_dict)
    return True
||||||
|
|
||||||
|
@staticmethod
def get_dataloaders(
    dataset: BatchedDataset,
    batch_size: int,
    val_frac: float = 0.1,
    train_transform: Transform | None = None,
    val_transform: Transform | None = None,
    dataset_split_kwargs: SplitKwargs | None = None,
    dataset_balance_kwargs: BalanceKwargs | None = None,
    dataloader_kwargs: LoaderKwargs | None = None,
) -> tuple[DataLoader, DataLoader]:
    """
    Create training and validation dataloaders for the provided dataset.

    Parameters:
        dataset: source dataset; balanced in place when
            `dataset_balance_kwargs` is provided
        batch_size: requested batch size; capped at the split's size
        val_frac: fraction of the dataset held out for validation. When
            <= 0, no split is made: the train loader covers the full
            dataset and the returned val loader wraps an empty dataset.
        train_transform: post-transform applied to training items
        val_transform: post-transform applied to validation items
        dataset_split_kwargs: forwarded to `dataset.split(...)`
        dataset_balance_kwargs: forwarded to `dataset.balance(...)`
        dataloader_kwargs: overrides merged into both loaders' kwargs

    Returns:
        `(train_loader, val_loader)` tuple
    """

    if dataset_split_kwargs is None:
        dataset_split_kwargs = {}

    if dataset_balance_kwargs is not None:
        dataset.balance(**dataset_balance_kwargs)

    def _loader_kwargs(split_len: int) -> LoaderKwargs:
        # shared loader defaults; validation is also shuffled to prevent
        # homogeneous val batches
        kwargs: LoaderKwargs = {
            "batch_size": min(batch_size, split_len),
            "num_workers": 0,
            "shuffle": True,
        }
        if dataloader_kwargs is not None:
            kwargs = {**kwargs, **dataloader_kwargs}
        return kwargs

    if val_frac <= 0:
        # no validation split: train on everything. Return an *empty but
        # valid* map-style dataset for the val loader -- a bare abstract
        # `Dataset()` has no `__len__`/`__getitem__` and fails on first
        # use.
        dataset.post_transform = train_transform

        class _EmptyDataset(Dataset):
            def __len__(self) -> int:
                return 0

            def __getitem__(self, index: int) -> None:
                raise IndexError(index)

        return (
            DataLoader(dataset, **_loader_kwargs(len(dataset))),
            DataLoader(_EmptyDataset()),
        )

    train_dataset, val_dataset = dataset.split(
        [1 - val_frac, val_frac],
        **dataset_split_kwargs,
    )

    # Dataset.split() returns light Subset objects of shallow copies of the
    # underlying dataset; can change the transform attribute of both splits
    # w/o overwriting
    train_dataset.post_transform = train_transform
    val_dataset.post_transform = val_transform

    train_loader = DataLoader(
        train_dataset, **_loader_kwargs(len(train_dataset))
    )
    val_loader = DataLoader(val_dataset, **_loader_kwargs(len(val_dataset)))

    return train_loader, val_loader
||||||
|
|
||||||
|
def _summarize(self, writer: SummaryWriter, epoch: int) -> None:
    """
    Flush accumulated summary records to the TB summary writer and print
    a per-metric mean report for the epoch.
    """

    # forward every (value, step) record to TensorBoard while grouping
    # raw values per metric for the console report
    per_metric_values: dict[str, list[float]] = defaultdict(list)
    for metric, records in self._summary.items():
        for recorded_value, recorded_step in records:
            writer.add_scalar(metric, recorded_value, recorded_step)
            per_metric_values[metric].append(recorded_value)

    print(f"==== Epoch [{epoch}] summary ====")
    for metric, metric_values in per_metric_values.items():
        mean_value = torch.tensor(metric_values).mean().item()
        print(f"> ({len(metric_values)}) {metric} :: {mean_value:.2f}")

    writer.flush()

    # start a fresh accumulation window
    self._summary = defaultdict(list)
||||||
|
|
||||||
|
def _get_optimizer_parameters(
    self,
    optimizer: torch.optim.Optimizer,
) -> list[Tensor]:
    """
    Collect every parameter across the optimizer's parameter groups that
    currently has a populated gradient.
    """

    params_with_grads: list[Tensor] = []
    for group in optimizer.param_groups:
        for candidate in group["params"]:
            if candidate.grad is not None:
                params_with_grads.append(candidate)
    return params_with_grads
||||||
|
|
||||||
|
def _add_summary_item(self, name: str, value: float) -> None:
    """
    Record a (value, global_step) pair under `name`; records are flushed
    to TensorBoard later by `_summarize()`.
    """
    self._summary[name].append((value, self._step))
||||||
|
|
||||||
|
def save_model(
    self,
    epoch: int,
    chkpt_dir: str | Path,
    dir_prefix: str,
) -> None:
    """
    Save a model checkpoint to
    `<chkpt_dir>/<dir_prefix>/m_<EstimatorClass>-e_<epoch>.pth`.
    """

    # serialize the state dict into memory, then write it in one shot
    buffer = BytesIO()
    torch.save(self.estimator.state_dict(), buffer)

    target_dir = Path(chkpt_dir, dir_prefix)
    target_dir.mkdir(parents=True, exist_ok=True)

    estimator_name = type(self.estimator).__name__
    target_path = target_dir / f"m_{estimator_name}-e_{epoch}.pth"
    target_path.write_bytes(buffer.getvalue())
||||||
|
|
||||||
|
def load_model(
    self,
    epoch: int,
    chkpt_dir: str,
) -> None:
    """
    Load a model checkpoint from a given epoch.

    Note that this assumes the model was saved via `Trainer.save_model()`,
    and the estimator provided to this `Trainer` instance matches the
    architecture of the checkpoint model being loaded.
    """

    estimator_name = type(self.estimator).__name__
    checkpoint_path = Path(chkpt_dir, f"m_{estimator_name}-e_{epoch}.pth")

    # read the raw bytes and deserialize tensor data only
    state_dict = torch.load(
        BytesIO(checkpoint_path.read_bytes()),
        weights_only=True,
    )
    self.estimator.load_state_dict(state_dict)
||||||
11
trainlib/transform.py
Normal file
11
trainlib/transform.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
class Transform[I]:
|
||||||
|
"""
|
||||||
|
Dataset transform base class.
|
||||||
|
|
||||||
|
In places that directly reference a base ``Transform[I]``, a hint
|
||||||
|
``Callable[[I], I]`` would suffice. This class exists to allow nominal
|
||||||
|
checks for purpose-built transforms.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(self, item: I) -> I:
|
||||||
|
raise NotImplementedError
|
||||||
0
trainlib/utils/__init__.py
Normal file
0
trainlib/utils/__init__.py
Normal file
54
trainlib/utils/job.py
Normal file
54
trainlib/utils/job.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
import logging
|
||||||
|
import concurrent
|
||||||
|
from concurrent.futures import Future, as_completed
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
from colorama import Fore, Style
|
||||||
|
|
||||||
|
from mema.util.text import color_text
|
||||||
|
|
||||||
|
logger: logging.Logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def process_futures(
    futures: list[Future],
    desc: str | None = None,
    unit: str | None = None,
) -> None:
    """
    Await a collection of futures, tracking completion with a progress bar.

    Each future's result is consumed only to classify the outcome as a
    success, cancellation, or error; the results themselves are discarded.
    Cancellations and errors are logged, never re-raised.

    Parameters:
        futures: futures to await
        desc: progress bar description (default: "Awaiting futures")
        unit: progress bar unit label (default: "it")
    """

    if desc is None:
        desc = "Awaiting futures"

    if unit is None:
        unit = "it"

    # completion counters by outcome
    success = 0
    cancelled = 0
    errored = 0
    submitted = len(futures)
    progress_bar = tqdm(
        total=len(futures),
        desc=f"{desc} [submitted {len(futures)}]",
        unit=unit,
    )

    for future in as_completed(futures):
        try:
            future.result()
            success += 1
        except concurrent.futures.CancelledError as e:
            cancelled += 1
            logger.error(f'Future cancelled; "{e}"')
        except Exception as e:
            errored += 1
            logger.warning(f'Future failed with unknown exception "{e}"')

        # refresh the bar with colorized per-outcome counts
        suc_txt = color_text(f"{success}", Fore.GREEN)
        can_txt = color_text(f"{cancelled}", Fore.YELLOW)
        err_txt = color_text(f"{errored}", Fore.RED)
        tot_txt = color_text(f"{success+cancelled+errored}", Style.BRIGHT)
        progress_bar.set_description(
            f"{desc} [{tot_txt} / {submitted} | {suc_txt} {can_txt} {err_txt}]"
        )
        progress_bar.update(n=1)

    progress_bar.close()
||||||
25
trainlib/utils/module.py
Normal file
25
trainlib/utils/module.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class ModelWrapper(nn.Module):
    """
    Adapter that calls a keyword-argument model via a single positional
    dict.

    Used when a caller can only pass positional inputs (e.g. graph
    tracing via `SummaryWriter.add_graph`) to a model whose forward
    expects keyword arguments.
    """

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model

    def forward(self, kwargs):
        # splat the single dict argument into the wrapped model's forward
        return self.model(**kwargs)
||||||
|
|
||||||
|
|
||||||
|
def get_grad_norm(model: nn.Module, p: int = 2) -> float:
    """
    Compute the p-norm over all gradients currently populated on `model`'s
    trainable parameters.

    Parameters without a gradient (or with `requires_grad=False`)
    contribute nothing; with no gradients at all the norm is 0.0.
    """

    total = 0.0
    for weight in model.parameters():
        if weight.requires_grad and weight.grad is not None:
            total += float(torch.abs(weight.grad).pow(p).sum().item())

    return total ** (1 / p)
||||||
|
|
||||||
8
trainlib/utils/text.py
Normal file
8
trainlib/utils/text.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from colorama import Style
|
||||||
|
|
||||||
|
|
||||||
|
def color_text(text: str, *colorama_args: Any) -> str:
    """
    Wrap `text` in the given colorama codes, resetting all styling after.
    """
    prefix = "".join(colorama_args)
    return f"{prefix}{text}{Style.RESET_ALL}"
||||||
|
|
||||||
52
trainlib/utils/type.py
Normal file
52
trainlib/utils/type.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
from typing import Any, TypedDict
|
||||||
|
from collections.abc import Callable, Iterable
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.utils.data.sampler import Sampler
|
||||||
|
|
||||||
|
from mema.dataset import BatchedDataset
|
||||||
|
|
||||||
|
|
||||||
|
class LoaderKwargs(TypedDict, total=False):
    """
    Optional keyword arguments forwarded to `torch.utils.data.DataLoader`.
    """

    batch_size: int
    shuffle: bool
    sampler: Sampler | Iterable | None
    batch_sampler: Sampler[list] | Iterable[list] | None
    num_workers: int
    collate_fn: Callable[[list], Any]
    pin_memory: bool
    drop_last: bool
    timeout: float
    worker_init_fn: Callable[[int], None]
    multiprocessing_context: object
    generator: object
    prefetch_factor: int
    persistent_workers: bool
    pin_memory_device: str
    in_order: bool
||||||
|
|
||||||
|
|
||||||
|
class SplitKwargs(TypedDict, total=False):
    """
    Optional keyword arguments forwarded to `BatchedDataset.split()`.
    """

    dataset: BatchedDataset | None
    by_attr: str | list[str | None] | None
    shuffle_strata: bool
|
||||||
|
|
||||||
|
class BalanceKwargs(TypedDict, total=False):
    """
    Optional keyword arguments forwarded to `BatchedDataset.balance()`.
    """

    by_attr: str | list[str | None] | None
    split_min_sizes: list[int] | None
    split_max_sizes: list[int] | None
    shuffle_strata: bool
||||||
|
|
||||||
|
|
||||||
|
class OptimizerKwargs(TypedDict, total=False):
    """
    Optional keyword arguments for Adam-family torch optimizers (used with
    `torch.optim.AdamW` in estimator `optimizers()` implementations).
    """

    lr: float | Tensor
    betas: tuple[float | Tensor, float | Tensor]
    eps: float
    weight_decay: float
    amsgrad: bool
    maximize: bool
    foreach: bool | None
    capturable: bool
    differentiable: bool
    fused: bool | None
|
||||||
Reference in New Issue
Block a user