initial commit
This commit is contained in:
20
.gitignore
vendored
Normal file
20
.gitignore
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
# generic
|
||||
__pycache__/
|
||||
*.egg-info/
|
||||
.python-version
|
||||
|
||||
# package-specific
|
||||
.ipynb_checkpoints/
|
||||
.pytest_cache/
|
||||
|
||||
# vendor/build files
|
||||
dist/
|
||||
build/
|
||||
doc/_autoref/
|
||||
doc/_autosummary/
|
||||
doc/_build/
|
||||
|
||||
# misc local
|
||||
/Makefile
|
||||
notebooks/
|
||||
|
||||
105
README.md
Normal file
105
README.md
Normal file
@@ -0,0 +1,105 @@
|
||||
# Overview
|
||||
Package summary goes here, ideally with a diagram
|
||||
|
||||
# Install
|
||||
Installation instructions
|
||||
|
||||
```sh
|
||||
pip install <package>
|
||||
```
|
||||
|
||||
or as a CLI tool
|
||||
|
||||
```sh
|
||||
uv tool install <package>
|
||||
```
|
||||
|
||||
# Development
|
||||
- Initialize/synchronize the project with `uv sync`, creating a virtual
|
||||
environment with base package dependencies.
|
||||
- Depending on needs, install the development dependencies with `uv sync
|
||||
--extra dev`.
|
||||
|
||||
# Testing
|
||||
- To run the unit tests, make sure to first have the test dependencies
|
||||
installed with `uv sync --extra test`, then run `make test`.
|
||||
- For notebook testing, run `make install-kernel` to make the environment
|
||||
available as a Jupyter kernel (to be selected when running notebooks).
|
||||
|
||||
# Documentation
|
||||
- Install the documentation dependencies with `uv sync --extra doc`.
|
||||
- Run `make docs-build` (optionally preceded by `make docs-clean`), and serve
|
||||
  locally with `make docs-serve`.
|
||||
|
||||
# Development remarks
|
||||
- Across `Trainer` / `Estimator` / `Dataset`, I've considered a
|
||||
`ParamSpec`-based typing scheme to better orchestrate alignment in the
|
||||
`Trainer.train()` loop, e.g., so we can statically check whether a dataset
|
||||
appears to be fulfilling the argument requirements for the estimator's
|
||||
`loss()` / `metrics()` methods. Something like
|
||||
|
||||
```py
|
||||
class Estimator[**P](nn.Module):
|
||||
def loss(
|
||||
self,
|
||||
input: Tensor,
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> Generator:
|
||||
...
|
||||
|
||||
class Trainer[**P]:
|
||||
def __init__(
|
||||
self,
|
||||
estimator: Estimator[P],
|
||||
...
|
||||
): ...
|
||||
```
|
||||
|
||||
might be how we begin threading signatures. But ensuring dataset items can
|
||||
match `P` is challenging. You can consider a "packed" object where we
|
||||
obfuscate passing data through `P`-signatures:
|
||||
|
||||
```py
|
||||
class PackedItem[**P]:
|
||||
def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
|
||||
self._args = args
|
||||
self._kwargs = kwargs
|
||||
|
||||
def apply[R](self, func: Callable[P, R]) -> R:
|
||||
return func(*self._args, **self._kwargs)
|
||||
|
||||
|
||||
class BatchedDataset[U, R, I, **P](Dataset):
|
||||
@abstractmethod
|
||||
def _process_item_data(
|
||||
self,
|
||||
item_data: I,
|
||||
item_index: int,
|
||||
) -> PackedItem[P]:
|
||||
...
|
||||
|
||||
def __iter__(self) -> Iterator[PackedItem[P]]:
|
||||
...
|
||||
```
|
||||
|
||||
Meaningfully shaping those signatures is what remains, but you can't really
|
||||
do this, not with typical type expression flexibility. For instance, if I'm
|
||||
trying to appropriately type my base `TupleDataset`:
|
||||
|
||||
```py
|
||||
class SequenceDataset[I, **P](HomogenousDataset[int, I, I, P]):
|
||||
...
|
||||
|
||||
class TupleDataset[I](SequenceDataset[tuple[I, ...], ??]):
|
||||
...
|
||||
```
|
||||
|
||||
Here there's no way for me to shape a `ParamSpec` to indicate arbitrarily
|
||||
many arguments of a fixed type (`I` in this case) to allow me to unpack my
|
||||
item tuples into an appropriate `PackedItem`.
|
||||
|
||||
Until this (among other issues) becomes clearer, I'm setting up around a
|
||||
simpler `TypedDict` type variable. We won't have particularly strong static
|
||||
checks for item alignment inside `Trainer`, but this seems about as good as I
|
||||
can get around the current infrastructure.
|
||||
84
pyproject.toml
Normal file
84
pyproject.toml
Normal file
@@ -0,0 +1,84 @@
|
||||
[build-system]
|
||||
requires = ["setuptools", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "trainlib"
|
||||
version = "0.1.0"
|
||||
description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training."
|
||||
requires-python = ">=3.13"
|
||||
authors = [
|
||||
{ name="Sam Griesemer", email="git@olog.io" },
|
||||
]
|
||||
readme = "README.md"
|
||||
license = "MIT"
|
||||
keywords = [
|
||||
"machine-learning",
|
||||
]
|
||||
classifiers = [
|
||||
"Programming Language :: Python",
|
||||
"Operating System :: OS Independent",
|
||||
"Development Status :: 3 - Alpha",
|
||||
|
||||
"Intended Audience :: Developers",
|
||||
"Intended Audience :: End Users/Desktop",
|
||||
]
|
||||
dependencies = [
|
||||
"colorama>=0.4.6",
|
||||
"matplotlib>=3.10.8",
|
||||
"numpy>=2.4.1",
|
||||
"tensorboard>=2.20.0",
|
||||
"torch>=2.5.1",
|
||||
"tqdm>=4.67.1",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
trainlib = "trainlib.__main__:main"
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"ipykernel",
|
||||
]
|
||||
doc = [
|
||||
"furo",
|
||||
"myst-parser",
|
||||
"sphinx",
|
||||
"sphinx-togglebutton",
|
||||
"sphinx-autodoc-typehints",
|
||||
]
|
||||
test = [
|
||||
"pytest",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://doc.olog.io/trainlib"
|
||||
Documentation = "https://doc.olog.io/trainlib"
|
||||
Repository = "https://git.olog.io/olog/trainlib"
|
||||
Issues = "https://git.olog.io/olog/trainlib/issues"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["trainlib*"]
|
||||
|
||||
# for static data files under package root
|
||||
# [tool.setuptools.package-data]
|
||||
# "<package>" = ["data/*.toml"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 79
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["ANN", "E", "F", "UP", "B", "SIM", "I", "C4", "PERF"]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
length-sort = true
|
||||
order-by-type = false
|
||||
force-sort-within-sections = false
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"tests/**" = ["S101"]
|
||||
"**/__init__.py" = ["F401"]
|
||||
|
||||
[tool.ruff.format]
|
||||
quote-style = "double"
|
||||
indent-style = "space"
|
||||
docstring-code-format = true
|
||||
0
trainlib/__init__.py
Normal file
0
trainlib/__init__.py
Normal file
964
trainlib/dataset.py
Normal file
964
trainlib/dataset.py
Normal file
@@ -0,0 +1,964 @@
|
||||
"""
|
||||
Marginalizing out the modality layer:
|
||||
|
||||
With ``domain`` being an instance variable, one possible interpretation of
|
||||
the object structures here is that one could completely abstract away
|
||||
the domain model, defining only item structures and processing data. You
|
||||
could have a single dataset definition for a particular concrete dataset,
|
||||
and so long as we're talking about the same items, it can be instantiated
|
||||
using *any domain*. You wouldn't need specific subclasses for disk or
|
||||
network or in-memory; you can tell it directly at runtime.
|
||||
|
||||
    That's an eventual possibility, anyway. As it stands, however, this is
|
||||
effectively impossible:
|
||||
|
||||
You can't easily abstract the batch -> item splitting process, i.e.,
|
||||
``_process_batch_data()``. A list-based version of the dataset you're
|
||||
trying to define might have an individual item tuple at every index,
|
||||
whereas a disk-based version might have tuples batched across a few files.
|
||||
This can't reliably be inferred, nor can it be pushed to the
|
||||
``Domain``-level without needing equal levels of specialization (you'd just
|
||||
end up needing the exact same structural distinctions in the ``Domain``
|
||||
hierarchy). So *somewhere* you need a batch splitting implementation that
|
||||
is both item structure-dependent *and* domain-dependent...the question is
|
||||
how dynamic you're willing to be about where it comes from. Right now, we
|
||||
require this actually be defined in the ``_process_batch_data()`` method,
|
||||
meaning you'll need a specific ``Dataset`` class for each domain you want
|
||||
to support (e.g., ``MNISTDisk``, ``MNISTList``, ``MNISTNetwork``, etc), or
|
||||
at least for each domain where "interpreting" a batch could possibly
|
||||
differ. This is a case where the interface is all that enforces a
|
||||
distinction: if you've got two domains that can be counted on to yield
|
||||
batches in the exact same way and can use the same processing, then you
|
||||
could feasibly provide ``Domain`` objects from either at runtime and have
|
||||
no issues. We're "structurally blind" to any differentiation beyond the URI
|
||||
and resource types by design, so two different domain implementations with
|
||||
the same type signature ``Domain[U, R]`` should be expected to work fine at
|
||||
runtime (again, so long as they don't also need different batch
|
||||
processing), but that's not affording us much flexibility, i.e., most of
|
||||
the time we'll still be defining new dataset classes for each domain.
|
||||
|
||||
I initially flagged this as feasible, however, because one could imagine
|
||||
accepting a batch processing method upon instantiation rather than
|
||||
structurally bolting it into the ``Dataset`` definition. This would require
|
||||
knowledge of the item structure ``I`` as well as the ``Domain[U, R]``, so
|
||||
such a function will always have to be (I, U, R)-dependent. It nevertheless
|
||||
would take out some of the pain of having to define new dataset classes;
|
||||
instead, you'd just need to define the batch processing method. I see this
|
||||
as a worse alternative to just defining *inside* a safe context like a new
|
||||
dataset class: you know the types you have to respect, and you stick that
|
||||
method exactly in a context where it's understood. Freeing this up doesn't
|
||||
lighten the burden of processing logic, it just changes *when* it has to be
|
||||
provided, and that's not worth much (to me) in this case given the bump in
|
||||
complexity. (Taking this to the extreme: you could supply *all* of an
|
||||
object's methods "dynamically" and glue them together at runtime so long as
|
||||
they all played nice. But wherever you were "laying them out" beforehand is
|
||||
exactly the job of a class to begin with, so you don't end up with anything
|
||||
more dynamic. All we're really discussing here is pushing around
|
||||
unavoidable complexity inside and outside of the "class walls," and in the
|
||||
particular case of ``_process_batch_data()``, it feels much better when
|
||||
it's on the inside.)
|
||||
|
||||
Holding:
|
||||
@abstractmethod
|
||||
def _get_uri_groups(self) -> Iterable[tuple[U, ...]]:
|
||||
Get URI groups for each batch.
|
||||
|
||||
If there's more than one URI per batch (e.g., a data file and a
|
||||
metadata file), zip the URIs such that we have a tuple of URIs per
|
||||
batch.
|
||||
|
||||
Note that this effectively defines the index style over batches in the
|
||||
attached domain. We get an ``int -> tuple[U, ...]`` map that turns
|
||||
batch indices into URIs that can be read under the domain.
|
||||
``get_batch()`` turns an integer index into its corresponding
|
||||
``tuple[U, ...]``, reading the resources with ``_read_resources()`` in
|
||||
the tuple, treating them as providers of batched data.
|
||||
``_read_resources()`` passes through to the attached domain logic,
|
||||
which, although common, need not supply an explicit iterable of batch
|
||||
items: we just access items with ``__getitem__()`` and may ask for
|
||||
``__len__``. So the returned URI group collection (this method) does
|
||||
need to be iterable to measure the number of batches, but the batch
|
||||
objects that are ultimately produced by these URI groups need not be
|
||||
iterables themselves.
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def _read_resources(
|
||||
self,
|
||||
uri_group: tuple[U, ...],
|
||||
batch_index: int
|
||||
) -> tuple[R, ...]:
|
||||
Read batch files at the provided paths.
|
||||
|
||||
This method should operate on a single tuple from the list of batch
|
||||
tuples returned by the ``_get_uri_groups()`` method. That is, it reads
|
||||
all of the resources for a single batch and returns a tuple of the same
|
||||
size with their contents.
|
||||
|
||||
Note: the dependence on a batch index is mostly here to make
|
||||
multi-dataset composition easier later. In-dataset, you don't need to
|
||||
        know the batch index to simply process URIs, but across datasets you
|
||||
need it to find out the origin of the batch (and process those URIs
|
||||
accordingly).
|
||||
|
||||
return tuple(self.domain.read(uri) for uri in uri_group)
|
||||
|
||||
# pulling the type variable out of the inline generic b/c `ty` has trouble
|
||||
# understanding bound type variables in subclasses (specifically with Self@)
|
||||
T = TypeVar("T", bound=NamedTuple)
|
||||
|
||||
class NamedTupleDataset[I](Dataset):
|
||||
def __init__(self, data_list: list[I]) -> None:
|
||||
self.data_list = data_list
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.data_list)
|
||||
|
||||
def __getitem__(self, index: int) -> I:
|
||||
return self.data_list[index]
|
||||
"""
|
||||
|
||||
import math
|
||||
import random
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from copy import copy
|
||||
from bisect import bisect
|
||||
from typing import Unpack, TypedDict
|
||||
from functools import lru_cache
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Iterator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from trainlib.utils import job
|
||||
from trainlib.domain import Domain, SequenceDomain
|
||||
from trainlib.transform import Transform
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatasetKwargs[I](TypedDict, total=False):
    """
    Optional keyword arguments shared by ``BatchedDataset``-style
    constructors.

    Keys mirror the optional parameters of ``BatchedDataset.__init__``;
    ``total=False`` makes every key optional. Presumably consumed via
    ``Unpack[DatasetKwargs[I]]`` in subclass signatures (``Unpack`` is
    imported above) — confirm against subclass definitions.
    """

    # transform applied to items while loading (before caching)
    pre_transform: Transform[I]
    # transform applied just before an item is returned (after caching)
    post_transform: Transform[I]
    # max number of batches held in the LRU batch cache
    batch_cache_limit: int
    # whether to load all batches into memory at construction time
    preload: bool
    # worker-thread count for preloading
    num_workers: int
|
||||
|
||||
|
||||
class BatchedDataset[U, R, I](Dataset):
|
||||
"""
|
||||
Generic dataset that dynamically pulls batched data from resources (e.g.,
|
||||
files, remote locations, streams, etc).
|
||||
|
||||
The class is generic over a URI type ``U``, a resource type ``R`` (both of
|
||||
which are used to concretize a domain ``Domain[U, R]``), and an item type
|
||||
``T`` (which has a ``tuple`` upper bound).
|
||||
|
||||
Pipeline overview:
|
||||
|
||||
```
|
||||
Domain -> [U] (get _batch_uris)
|
||||
U -> R (domain access ; Rs provide batches)
|
||||
R -> [I] (cache here ; _process_batch_data to use load_transform)
|
||||
[I] -> I (human item obj ; _get_item)
|
||||
I -> **P (final packed item ; __getitem__ to use transform)
|
||||
```
|
||||
|
||||
Note^1: as far as positioning, this class is meant to play nice with
|
||||
PyTorch DataLoaders, hence the inheritance from ``torch.Dataset``. The
|
||||
value add for this over the ``torch.Dataset`` base is almost entirely in
|
||||
the logic it implements to map out of *batched resources* that are holding
|
||||
data, and flattening it out into typical dataset items. There are also some
|
||||
QoL items when it comes to splitting and balancing samples.
|
||||
|
||||
Note^2: even though ``Domains`` implement iterators over their URIs, this
|
||||
doesn't imply a ``BatchedDataset`` is iterable. This just means we can walk
|
||||
over the resources that provide data, but we don't necessarily presuppose
|
||||
an ordered walk over samples within batches. Point being:
|
||||
``torch.Dataset``, not ``torch.IterableDataset``, is the appropriate
|
||||
superclass, even when we're working around iterable ``Domains``.
|
||||
|
||||
Note^3: transforms are expected to operate on ``I``-items and produce
|
||||
``I``-items. They shouldn't be the "introducers" of ``I`` types from some
|
||||
other intermediate representation, nor should they map from ``I`` to
|
||||
something else. Point being: the dataset definition should be able to map
|
||||
resources ``R`` to ``I`` without a transform: that much should be baked
|
||||
into the class definition. If you find you're expecting the transform to do
|
||||
that for you, you should consider pulling in some common structure across
|
||||
the allowed transforms and make it a fixed part of the class.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
domain: Domain[U, R],
|
||||
pre_transform: Transform[I] | None = None,
|
||||
post_transform: Transform[I] | None = None,
|
||||
batch_cache_limit: int | None = None,
|
||||
preload: bool = False,
|
||||
num_workers: int = 1,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters:
|
||||
pre_transform: transform to apply over items during loading (in
|
||||
``_process_batch_data()``), i.e., *before* going into
|
||||
persistent storage
|
||||
post_transform: transform to apply just prior to returning an item
|
||||
(in ``_process_item_data()``), i.e., only *after* retrieval
|
||||
from persistent storage
|
||||
batch_cache_limit: the max number of max batches to cache at any
|
||||
one time
|
||||
preload: whether to load all data into memory during instantiation
|
||||
"""
|
||||
|
||||
self.domain = domain
|
||||
self.pre_transform = pre_transform
|
||||
self.post_transform = post_transform
|
||||
self.batch_cache_limit = batch_cache_limit
|
||||
self.num_workers = num_workers
|
||||
|
||||
logger.info("Fetching URIs...")
|
||||
self._batch_uris: list[U] = list(domain)
|
||||
|
||||
self._indices: list[int] | None = None
|
||||
self._dataset_len: int | None = None
|
||||
self._num_batches: int = len(domain)
|
||||
|
||||
self.get_batch: Callable[[int], list[I]] = lru_cache(
|
||||
maxsize=batch_cache_limit
|
||||
)(self._get_batch)
|
||||
|
||||
if preload:
|
||||
self.load_all(num_workers=num_workers)
|
||||
|
||||
@abstractmethod
|
||||
def _get_dataset_len(self) -> int:
|
||||
"""
|
||||
Calculate the total dataset size in units of samples (not batches).
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _get_batch_for_item(self, item_index: int) -> tuple[int, int]:
|
||||
"""
|
||||
Return the index of the batch containing the item at the provided item
|
||||
index, and the index of the item within that batch.
|
||||
|
||||
The behavior of this method can vary depending on what we know about
|
||||
batch sizes, and should therefore be implemented by inheriting classes.
|
||||
|
||||
Returns:
|
||||
batch_index: int
|
||||
index_in_batch: int
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _process_batch_data(
|
||||
self,
|
||||
batch_data: R,
|
||||
batch_index: int,
|
||||
) -> list[I]:
|
||||
"""
|
||||
Process raw domain resource data (e.g., parse JSON or load tensors) and
|
||||
split accordingly, such that the returned batch is a collection of
|
||||
``I`` items.
|
||||
|
||||
If an inheriting class wants to allow dynamic transforms, this is the
|
||||
place to use a provided ``pre_transform``; the collection of items
|
||||
produced by this method are cached as a batch, so results from such a
|
||||
transform will be stored.
|
||||
|
||||
Parameters:
|
||||
batch_data: tuple of resource data
|
||||
batch_index: index of batch
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _process_item_data(
|
||||
self,
|
||||
item_data: I,
|
||||
item_index: int,
|
||||
) -> I:
|
||||
"""
|
||||
Process individual items and produce final tuples.
|
||||
|
||||
If an inheriting class wants to allow dynamic transforms, this is the
|
||||
place to use a provided ``post_transform``; items are pulled from the
|
||||
cache (if enabled) and processed before being returned as the final
|
||||
tuple outputs (so this processing is not persistent).
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_item(self, item_index: int) -> I:
|
||||
"""
|
||||
Get the item data and zip with the item header.
|
||||
|
||||
Items should be the most granular representation of dataset samples
|
||||
with maximum detail (i.e., yield a all available information). An
|
||||
iterator over these representations across all samples can be retrieved
|
||||
with `.items()`.
|
||||
|
||||
Note that return values from `__getitem__()` are "cleaned up" versions
|
||||
of this representation, with minimal info needed for training.
|
||||
"""
|
||||
|
||||
if item_index >= len(self):
|
||||
raise IndexError
|
||||
|
||||
# alt indices redefine index count
|
||||
item_index = self.indices[item_index]
|
||||
batch_index, index_in_batch = self._get_batch_for_item(item_index)
|
||||
|
||||
return self.get_batch(batch_index)[index_in_batch]
|
||||
|
||||
def _get_batch(self, batch_index: int) -> list[I]:
|
||||
"""
|
||||
Return the batch data for the provided index.
|
||||
|
||||
Note that we require a list return type. This is where the rubber meets
|
||||
the road in terms of expanding batches: the outputs here get cached, if
|
||||
caching is enabled. If we were to defer batch expansion to some later
|
||||
stage, that caching will be more or less worthless. For instance, if
|
||||
``._process_batch_data()`` was an iterator, at best the batch
|
||||
processing logic could be delayed, but then you'd either 1) cache the
|
||||
iterator reference, or 2) have to further delay caching until
|
||||
post-batch processing. Further, ``._process_batch_data()`` can (and
|
||||
often does) depend on the entire batch, so you can't handle that
|
||||
item-wise: you've got to pass all batch data in and can't act on
|
||||
slices, so doing this would be irrelevant anyway.
|
||||
|
||||
How about iterators out from ``_read_resources()``? Then your
|
||||
``_process_batch_data()`` can iterate as needed and do the work later?
|
||||
The point here is that this distinction is irrelevant because you're
|
||||
reading resources and processing the data here in the same method:
|
||||
they're always connected, and nothing would notice if you waited
|
||||
between steps. The only way this could matter is if you split the
|
||||
resource reading and batch processing steps across methods, but when it
|
||||
actually comes to accessing/caching the batch, you'd have to expand
|
||||
any delayed reads here. There's no way around needing to see all batch
|
||||
data at once here, and we don't want to make that ambiguous: ``list``
|
||||
output type it is.
|
||||
"""
|
||||
|
||||
logger.debug("Batch cache miss, reading from root...")
|
||||
|
||||
if batch_index >= self._num_batches:
|
||||
raise IndexError
|
||||
|
||||
batch_uri = self._batch_uris[batch_index]
|
||||
batch_data = self.domain[batch_uri]
|
||||
|
||||
return self._process_batch_data(batch_data, batch_index)
|
||||
|
||||
def load_all(self, num_workers: int | None = None) -> list[list[I]]:
|
||||
"""
|
||||
Preload all data batches into the cache.
|
||||
|
||||
Can be useful when dynamically pulling data (as it's requested) isn't
|
||||
desired. Requires that `cache_sample_limit=None`, i.e., the cache won't
|
||||
continually remove previous batches as they're loaded.
|
||||
"""
|
||||
|
||||
assert self.batch_cache_limit is None, "Preloading under cache limit"
|
||||
|
||||
if num_workers is None:
|
||||
num_workers = self.num_workers
|
||||
|
||||
thread_pool = ThreadPoolExecutor(max_workers=num_workers)
|
||||
|
||||
futures = []
|
||||
for batch_index in range(self._num_batches):
|
||||
future = thread_pool.submit(
|
||||
self.get_batch,
|
||||
batch_index,
|
||||
)
|
||||
futures.append(future)
|
||||
|
||||
job.process_futures(futures, "Loading dataset batches", "batch")
|
||||
thread_pool.shutdown(wait=True)
|
||||
|
||||
return [future.result() for future in futures]
|
||||
|
||||
def split(
|
||||
self,
|
||||
fracs: list[float],
|
||||
dataset: "BatchedDataset | None" = None,
|
||||
by_attr: str | list[str | None] | None = None,
|
||||
shuffle_strata: bool = True,
|
||||
) -> list["BatchedDataset"]:
|
||||
"""
|
||||
Split dataset into fractional pieces by data attribute.
|
||||
|
||||
If `by_attr` is None, recovers typical fractional splitting of dataset
|
||||
items, partitioning by size. Using None anywhere will index each item
|
||||
into its own bucket, i.e., by its index. For instance,
|
||||
|
||||
- by_attr=["color"] -> {("red", 1), ("red", 2)},
|
||||
{("blue", 1), ("blue", 2)}
|
||||
|
||||
Splits on the attribute such that each subset contains entire strata
|
||||
of the attribute. "Homogeneity within clusters"
|
||||
|
||||
- `by_attr=["color", None]` -> {("red", 1), ("blue", 1)},
|
||||
{("red", 2), ("blue", 2)}
|
||||
|
||||
Stratifies by attribute and then splits "by index" within, uniformly
|
||||
grabbing samples across strata to form new clusters. "Homogeneity
|
||||
across clusters"
|
||||
|
||||
Note that the final list of Subsets returned are built from shallow
|
||||
copies of the underlying dataset (i.e., `self`) to allow manual
|
||||
intervention with dataset attributes (e.g., setting the splits to have
|
||||
different `transform`s). This is subject to possibly unexpected
|
||||
behavior if re-caching data or you need a true copy of all data in
|
||||
memory, but should otherwise leave most interactions unchanged.
|
||||
|
||||
Parameters:
|
||||
shuffle_strata: shuffle the strata order before split is drawn. We
|
||||
parameterize this because a dataloader-level shuffle operation
|
||||
will only change the order of the indices in the resulting
|
||||
splits; only a shuffle of items inside the strata can change
|
||||
the actual content of the splits themselves.
|
||||
"""
|
||||
|
||||
if by_attr == []:
|
||||
raise ValueError("Cannot parse empty value list")
|
||||
|
||||
assert (
|
||||
math.isclose(sum(fracs), 1) and sum(fracs) <= 1
|
||||
), "Fractions do not sum to 1"
|
||||
|
||||
if isinstance(by_attr, str) or by_attr is None:
|
||||
by_attr = [by_attr]
|
||||
|
||||
if dataset is None:
|
||||
dataset = self
|
||||
# dataset = DictDataset([
|
||||
# self._get_item(i) for i in range(len(self))
|
||||
# ])
|
||||
|
||||
# group samples by specified attr
|
||||
attr_dict = defaultdict(list)
|
||||
attr_key, by_attr = by_attr[0], by_attr[1:]
|
||||
# for i in range(len(dataset)):
|
||||
for i, item in enumerate(dataset.items()):
|
||||
# item = dataset[i]
|
||||
if attr_key is None:
|
||||
attr_val = i
|
||||
elif attr_key in item:
|
||||
attr_val = item[attr_key]
|
||||
else:
|
||||
raise IndexError(f"Attribute {attr_key} not in dataset item")
|
||||
attr_dict[attr_val].append(i)
|
||||
|
||||
if by_attr == []:
|
||||
attr_keys = list(attr_dict.keys())
|
||||
|
||||
# shuffle keys; randomized group-level split
|
||||
if shuffle_strata:
|
||||
random.shuffle(attr_keys)
|
||||
|
||||
# considering: defer to dataloader shuffle param; should have same
|
||||
# effect shuffle values; has no impact on where the split is drawn
|
||||
# for attr_vals in attr_dict.values():
|
||||
# random.shuffle(attr_vals)
|
||||
|
||||
# fractionally split over attribute keys
|
||||
offset, splits = 0, []
|
||||
for frac in fracs[:-1]:
|
||||
frac_indices = []
|
||||
frac_size = int(frac * len(attr_keys))
|
||||
for j in range(offset, offset + frac_size):
|
||||
j_indices = attr_dict.pop(attr_keys[j])
|
||||
frac_indices.extend(j_indices)
|
||||
offset += frac_size
|
||||
splits.append(frac_indices)
|
||||
|
||||
rem_indices = []
|
||||
for r_indices in attr_dict.values():
|
||||
rem_indices.extend(r_indices)
|
||||
splits.append(rem_indices)
|
||||
|
||||
subsets = []
|
||||
for split in splits:
|
||||
subset = copy(dataset)
|
||||
subset.indices = split
|
||||
subsets.append(subset)
|
||||
|
||||
return subsets
|
||||
else:
|
||||
splits = [[] for _ in range(len(fracs))]
|
||||
for index_split in attr_dict.values():
|
||||
# subset = Subset(dataset, index_split)
|
||||
subset = copy(dataset)
|
||||
subset.indices = index_split
|
||||
subset_splits = self.split(
|
||||
fracs, subset, by_attr, shuffle_strata
|
||||
)
|
||||
|
||||
# unpack stratified splits
|
||||
for i, subset_split in enumerate(subset_splits):
|
||||
# splits[i].extend([
|
||||
# index_split[s] for s in subset_split.indices
|
||||
# ])
|
||||
splits[i].extend(subset_split.indices)
|
||||
|
||||
subsets = []
|
||||
for split in splits:
|
||||
subset = copy(dataset)
|
||||
subset.reset_indices()
|
||||
subset.indices = split
|
||||
subsets.append(subset)
|
||||
|
||||
return subsets
|
||||
|
||||
# considering: defer to dataloader shuffle param; should have same
|
||||
# effect shuffle each split after merging; may otherwise be homogenous
|
||||
# for split in splits:
|
||||
# random.shuffle(split)
|
||||
|
||||
# return [Subset(copy(self), split) for split in splits]
|
||||
|
||||
def balance(
|
||||
self,
|
||||
dataset: "BatchedSubset[U, R, I] | None" = None,
|
||||
by_attr: str | list[str | None] | None = None,
|
||||
split_min_sizes: list[int] | None = None,
|
||||
split_max_sizes: list[int] | None = None,
|
||||
shuffle_strata: bool = True,
|
||||
) -> None:
|
||||
self.indices = self._balance(
|
||||
dataset,
|
||||
by_attr,
|
||||
split_min_sizes,
|
||||
split_max_sizes,
|
||||
shuffle_strata,
|
||||
)
|
||||
|
||||
def _balance(
|
||||
self,
|
||||
dataset: "BatchedSubset[U, R, I] | None" = None,
|
||||
by_attr: str | list[str | None] | None = None,
|
||||
split_min_sizes: list[int] | None = None,
|
||||
split_max_sizes: list[int] | None = None,
|
||||
shuffle_strata: bool = True,
|
||||
) -> list[int]:
|
||||
"""
|
||||
Note: behavior is a little odd for nested behavior; not exactly
|
||||
perfectly uniform throughout. This is a little difficult: you can't
|
||||
exactly know ahead of time the size of the subgroups across splits
|
||||
"""
|
||||
|
||||
if by_attr == []:
|
||||
raise ValueError("Cannot parse empty value list")
|
||||
|
||||
if isinstance(by_attr, str) or by_attr is None:
|
||||
by_attr = [by_attr]
|
||||
|
||||
if dataset is None:
|
||||
dataset = BatchedSubset(self, self.indices)
|
||||
|
||||
if split_min_sizes == [] or split_min_sizes is None:
|
||||
split_min_sizes = [0]
|
||||
if split_max_sizes == [] or split_max_sizes is None:
|
||||
split_max_sizes = [len(dataset)]
|
||||
|
||||
# group samples by specified attr
|
||||
attr_dict = defaultdict(list)
|
||||
attr_key, by_attr = by_attr[0], by_attr[1:]
|
||||
for i in range(len(dataset)):
|
||||
item = dataset[i]
|
||||
if attr_key is None:
|
||||
attr_val = i
|
||||
elif attr_key in item:
|
||||
attr_val = getattr(item, attr_key)
|
||||
else:
|
||||
raise IndexError(f"Attribute {attr_key} not in dataset item")
|
||||
attr_dict[attr_val].append(i)
|
||||
|
||||
subset_splits = []
|
||||
split_min_size, split_min_sizes = (
|
||||
split_min_sizes[0],
|
||||
split_min_sizes[1:]
|
||||
)
|
||||
split_max_size, split_max_sizes = (
|
||||
split_max_sizes[0],
|
||||
split_max_sizes[1:]
|
||||
)
|
||||
for split_indices in attr_dict.values():
|
||||
if by_attr != []:
|
||||
subset_indices = self._balance(
|
||||
BatchedSubset(dataset, split_indices),
|
||||
by_attr,
|
||||
split_min_sizes,
|
||||
split_max_sizes,
|
||||
shuffle_strata,
|
||||
)
|
||||
split_indices = [split_indices[s] for s in subset_indices]
|
||||
subset_splits.append(split_indices)
|
||||
|
||||
# shuffle splits; randomized group-level split
|
||||
if shuffle_strata:
|
||||
random.shuffle(subset_splits)
|
||||
|
||||
# note: split_min_size is smallest allowed, min_split_size is smallest
|
||||
# observed
|
||||
valid_splits = [
|
||||
ss for ss in subset_splits
|
||||
if len(ss) >= split_min_size
|
||||
]
|
||||
min_split_size = min(len(split) for split in valid_splits)
|
||||
|
||||
subset_indices = []
|
||||
for split_indices in valid_splits:
|
||||
# if shuffle_strata:
|
||||
# random.shuffle(split_indices)
|
||||
subset_indices.extend(
|
||||
split_indices[: min(min_split_size, split_max_size)]
|
||||
)
|
||||
|
||||
# print(f"{attr_dict.keys()=}")
|
||||
# print(f"{[len(s) for s in valid_splits]=}")
|
||||
# print(f"{min_split_size=}")
|
||||
|
||||
return subset_indices
|
||||
|
||||
@property
|
||||
def indices(self) -> list[int]:
|
||||
if self._indices is None:
|
||||
self._indices = list(range(len(self)))
|
||||
return self._indices
|
||||
|
||||
@indices.setter
|
||||
def indices(self, indices: list[int]) -> None:
|
||||
"""
|
||||
Note: this logic facilitates nested re-indexing over the same base
|
||||
dataset. The underlying data remain the same, but when indices get set,
|
||||
you're effectively applying a mask over any existing indices, always
|
||||
operating *relative* to the existing mask.
|
||||
"""
|
||||
|
||||
# manually set new size
|
||||
self._dataset_len = len(indices)
|
||||
|
||||
# note: this is a little tricky and compact; follow what happens when
|
||||
# _indices aren't already set
|
||||
self._indices = [
|
||||
self.indices[index]
|
||||
for index in indices
|
||||
]
|
||||
|
||||
def reset_indices(self) -> None:
|
||||
self._indices = None
|
||||
self._dataset_len = None
|
||||
|
||||
def items(self) -> Iterator[I]:
|
||||
for i in range(len(self)):
|
||||
yield self._get_item(i)
|
||||
|
||||
def __len__(self) -> int:
|
||||
if self._dataset_len is None:
|
||||
self._dataset_len = self._get_dataset_len()
|
||||
|
||||
return self._dataset_len
|
||||
|
||||
def __getitem__(self, index: int) -> I:
|
||||
item_data = self._get_item(index)
|
||||
index = self.indices[index]
|
||||
|
||||
return self._process_item_data(item_data, index)
|
||||
|
||||
def __iter__(self) -> Iterator[I]:
|
||||
"""
|
||||
Note: this method isn't technically needed given ``__getitem__`` is
|
||||
defined and we operate cleanly over integer indices 0..(N-1), so even
|
||||
without an explicit ``__iter__``, Python will fall back to a reliable
|
||||
iteration mechanism. We nevertheless implement the trivial logic below
|
||||
to convey intent and meet static type checks for iterables.
|
||||
"""
|
||||
|
||||
return (self[i] for i in range(len(self)))
|
||||
|
||||
|
||||
class CompositeBatchedDataset[U, R, I](BatchedDataset[U, R, I]):
|
||||
"""
|
||||
Dataset class for wrapping individual datasets.
|
||||
|
||||
Note: because this remains a valid ``BatchedDataset``, we re-thread the
|
||||
generic type variables through the set of composed datasets. That is, they
|
||||
must have a common domain type ``Domain[U, R]``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
datasets: list[BatchedDataset[U, R, I]],
|
||||
) -> None:
|
||||
"""
|
||||
Parameters:
|
||||
datasets: list of datasets
|
||||
"""
|
||||
|
||||
self.datasets = datasets
|
||||
|
||||
self._indices: list[int] | None = None
|
||||
self._dataset_len: int | None = None
|
||||
|
||||
self._item_psum: list[int] = []
|
||||
self._batch_psum: list[int] = []
|
||||
|
||||
def _compute_prefix_sum(self, arr: list[int]) -> list[int]:
|
||||
if not arr:
|
||||
return []
|
||||
|
||||
prefix_sum = [0] * (len(arr) + 1)
|
||||
for i in range(len(arr)):
|
||||
prefix_sum[i + 1] = prefix_sum[i] + arr[i]
|
||||
|
||||
return prefix_sum
|
||||
|
||||
def _get_dataset_for_item(self, item_index: int) -> tuple[int, int]:
|
||||
dataset_index = bisect(self._item_psum, item_index) - 1
|
||||
index_in_dataset = item_index - self._item_psum[dataset_index]
|
||||
|
||||
return dataset_index, index_in_dataset
|
||||
|
||||
def _get_dataset_for_batch(self, batch_index: int) -> tuple[int, int]:
|
||||
dataset_index = bisect(self._batch_psum, batch_index) - 1
|
||||
index_in_dataset = batch_index - self._batch_psum[dataset_index]
|
||||
|
||||
return dataset_index, index_in_dataset
|
||||
|
||||
def _get_batch_for_item(self, item_index: int) -> tuple[int, int]:
|
||||
index_pair = self._get_dataset_for_item(item_index)
|
||||
dataset_index, index_in_dataset = index_pair
|
||||
dataset = self.datasets[dataset_index]
|
||||
|
||||
return dataset._get_batch_for_item(index_in_dataset)
|
||||
|
||||
def _get_dataset_len(self) -> int:
|
||||
self.load_all()
|
||||
|
||||
dataset_batch_sizes = [len(dataset) for dataset in self.datasets]
|
||||
dataset_sizes = [dataset._num_batches for dataset in self.datasets]
|
||||
|
||||
# this method will only be ran once; set this instance var
|
||||
self._item_psum = self._compute_prefix_sum(dataset_batch_sizes)
|
||||
self._batch_psum = self._compute_prefix_sum(dataset_sizes)
|
||||
|
||||
return self._item_psum[-1]
|
||||
|
||||
def _process_batch_data(
|
||||
self,
|
||||
batch_data: R,
|
||||
batch_index: int,
|
||||
) -> list[I]:
|
||||
index_pair = self._get_dataset_for_item(batch_index)
|
||||
dataset_index, index_in_dataset = index_pair
|
||||
dataset = self.datasets[dataset_index]
|
||||
|
||||
return dataset._process_batch_data(batch_data, index_in_dataset)
|
||||
|
||||
def _process_item_data(self, item_data: I, item_index: int) -> I:
|
||||
index_pair = self._get_dataset_for_item(item_index)
|
||||
dataset_index, index_in_dataset = index_pair
|
||||
dataset = self.datasets[dataset_index]
|
||||
|
||||
return dataset._process_item_data(item_data, index_in_dataset)
|
||||
|
||||
def _get_item(self, item_index: int) -> I:
|
||||
item_index = self.indices[item_index]
|
||||
|
||||
index_pair = self._get_dataset_for_item(item_index)
|
||||
dataset_index, index_in_dataset = index_pair
|
||||
dataset = self.datasets[dataset_index]
|
||||
|
||||
return dataset._get_item(index_in_dataset)
|
||||
|
||||
def _get_batch(self, batch_index: int) -> list[I]:
|
||||
index_pair = self._get_dataset_for_item(batch_index)
|
||||
dataset_index, index_in_dataset = index_pair
|
||||
dataset = self.datasets[dataset_index]
|
||||
|
||||
return dataset._get_batch(index_in_dataset)
|
||||
|
||||
def load_all(
|
||||
self,
|
||||
num_workers: int | None = None
|
||||
) -> list[list[I]]:
|
||||
batches = []
|
||||
for dataset in self.datasets:
|
||||
batches.extend(dataset.load_all(num_workers))
|
||||
return batches
|
||||
|
||||
@property
|
||||
def pre_transform(self) -> list[Transform[I] | None]:
|
||||
return [dataset.pre_transform for dataset in self.datasets]
|
||||
|
||||
@pre_transform.setter
|
||||
def pre_transform(self, pre_transform: Transform[I]) -> None:
|
||||
for dataset in self.datasets:
|
||||
dataset.pre_transform = pre_transform
|
||||
|
||||
@property
|
||||
def post_transform(self) -> list[Transform[I] | None]:
|
||||
return [dataset.post_transform for dataset in self.datasets]
|
||||
|
||||
@post_transform.setter
|
||||
def post_transform(self, post_transform: Transform[I]) -> None:
|
||||
for dataset in self.datasets:
|
||||
dataset.post_transform = post_transform
|
||||
|
||||
|
||||
class BatchedSubset[U, R, I](BatchedDataset[U, R, I]):
    """
    View over a subset of a ``BatchedDataset``, selected by explicit indices.
    """

    def __init__(
        self,
        dataset: BatchedDataset[U, R, I],
        indices: list[int],
    ) -> None:
        # parent dataset backing this subset view
        self.dataset = dataset
        # NOTE(review): if ``indices`` is a property on the base class (as in
        # the sibling index-masking code), this assignment runs the property
        # setter, which reads ``self._indices`` — confirm the base class
        # default-initializes that attribute before instances reach here
        self.indices = indices

    def _get_item(self, item_index: int) -> I:
        """
        Subset indices are "reset" in its context. Simply passes through

        Maps the subset-relative ``item_index`` through the stored indices
        and fetches from the wrapped dataset.
        """

        return self.dataset._get_item(self.indices[item_index])

    def _get_dataset_len(self) -> int:
        # subset size is just the number of selected indices
        return len(self.indices)
|
||||
|
||||
|
||||
class HomogenousDataset[U, R, I](BatchedDataset[U, R, I]):
|
||||
"""
|
||||
Batched dataset where batches are equally sized.
|
||||
|
||||
Subclass from this base when you can count on the reference data being
|
||||
prepared with fixed size batches (up to the last batch), e.g., that which
|
||||
has been prepared with a `Packer`. This can greatly improve measurement
|
||||
time of the dataset size by preventing the need for reading all batch files
|
||||
upfront, and reduces the cost of identifying item batches from O(log n) to
|
||||
O(1).
|
||||
|
||||
Methods left for inheriting classes:
|
||||
|
||||
- ``_process_item_data()``: item processing
|
||||
- ``_process_batch_data()``: batch processing
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
domain: Domain[U, R],
|
||||
**kwargs: Unpack[DatasetKwargs],
|
||||
) -> None:
|
||||
super().__init__(domain, **kwargs)
|
||||
|
||||
# determine batch size across dataset, along w/ possible partial final
|
||||
# batch
|
||||
bsize = rem = len(self.get_batch(self._num_batches - 1))
|
||||
if self._num_batches > 1:
|
||||
bsize, rem = len(self.get_batch(self._num_batches - 2)), bsize
|
||||
|
||||
self._batch_size: int = bsize
|
||||
self._batch_rem: int = rem
|
||||
|
||||
def _get_dataset_len(self) -> int:
|
||||
return self._batch_size * (self._num_batches - 1) + self._batch_rem
|
||||
|
||||
def _get_batch_for_item(self, item_index: int) -> tuple[int, int]:
|
||||
return item_index // self._batch_size, item_index % self._batch_size
|
||||
|
||||
|
||||
class HeterogenousDataset[U, R, I](BatchedDataset[U, R, I]):
|
||||
"""
|
||||
Batched dataset where batches have arbitrary size.
|
||||
|
||||
Methods left for inheriting classes:
|
||||
|
||||
- ``_process_item_data()``: item processing
|
||||
- ``_process_batch_data()``: batch processing
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
domain: Domain[U, R],
|
||||
**kwargs: Unpack[DatasetKwargs],
|
||||
) -> None:
|
||||
super().__init__(domain, **kwargs)
|
||||
|
||||
self._batch_size_psum: list[int] = []
|
||||
|
||||
def _compute_prefix_sum(self, arr: list[int]) -> list[int]:
|
||||
if not arr:
|
||||
return []
|
||||
|
||||
prefix_sum = [0] * (len(arr) + 1)
|
||||
for i in range(len(arr)):
|
||||
prefix_sum[i + 1] = prefix_sum[i] + arr[i]
|
||||
|
||||
return prefix_sum
|
||||
|
||||
def _get_dataset_len(self) -> int:
|
||||
# type error below: no idea why this is flagged
|
||||
batches = self.load_all()
|
||||
batch_sizes = [len(batch) for batch in batches]
|
||||
|
||||
# this method will only be ran once; set this instance var
|
||||
self._batch_size_psum = self._compute_prefix_sum(batch_sizes)
|
||||
|
||||
return self._batch_size_psum[-1]
|
||||
|
||||
def _get_batch_for_item(self, item_index: int) -> tuple[int, int]:
|
||||
batch_index = bisect(self._batch_size_psum, item_index) - 1
|
||||
index_in_batch = item_index - self._batch_size_psum[batch_index]
|
||||
|
||||
return batch_index, index_in_batch
|
||||
|
||||
|
||||
class SequenceDataset[I](HomogenousDataset[int, I, I]):
|
||||
"""
|
||||
Trivial dataset skeleton for sequence domains.
|
||||
|
||||
``I``-typed sequence items map directly to dataset items. To produce a
|
||||
fully concrete dataset, one still needs to define ``_process_item_data()``
|
||||
to map from ``I``-items to tuples.
|
||||
"""
|
||||
|
||||
domain: SequenceDomain[I]
|
||||
|
||||
def _process_batch_data(
|
||||
self,
|
||||
batch_data: I,
|
||||
batch_index: int,
|
||||
) -> list[I]:
|
||||
if self.pre_transform is not None:
|
||||
batch_data = self.pre_transform(batch_data)
|
||||
|
||||
return [batch_data]
|
||||
|
||||
|
||||
class TupleDataset[T](SequenceDataset[tuple[T, ...]]):
|
||||
"""
|
||||
Trivial sequence-of-tuples dataset.
|
||||
|
||||
This is the most straightforward line to a concrete dataset from a
|
||||
``BatchedDataset`` base class. That is: the underlying domain is a sequence
|
||||
whose items are mapped to single-item batches and are already tuples.
|
||||
"""
|
||||
|
||||
def _process_item_data(
|
||||
self,
|
||||
item_data: tuple[T, ...],
|
||||
item_index: int,
|
||||
) -> tuple[T, ...]:
|
||||
if self.post_transform is not None:
|
||||
item_data = self.post_transform(item_data)
|
||||
|
||||
return item_data
|
||||
0
trainlib/datasets/__init__.py
Normal file
0
trainlib/datasets/__init__.py
Normal file
179
trainlib/datasets/disk.py
Normal file
179
trainlib/datasets/disk.py
Normal file
@@ -0,0 +1,179 @@
|
||||
from io import BytesIO
|
||||
from abc import abstractmethod
|
||||
from typing import Any, NamedTuple
|
||||
from pathlib import Path
|
||||
from zipfile import ZipFile
|
||||
|
||||
from mema.dataset import HomogenousDataset
|
||||
from mema.domains.disk import DiskDomain
|
||||
|
||||
|
||||
class DiskDataset[T: NamedTuple](HomogenousDataset[Path, bytes, T]):
    """
    Base class for datasets whose resources are byte blobs read from disk.

    The class-level ``domain: DiskDomain`` declaration below is to satisfy
    the type checker, which

    1. Can't recognize an appropriately re-typed constructor arg like

           def __init__(
               self,
               domain: DiskDomain,
               ...
           ): ...

       This *does* match the parent generic for the U=Path, R=bytes context

           def __init__(
               self,
               domain: Domain[U, R],
               ...
           ): ...

       but the type checker doesn't see this.

    2. "Lifted" type variables out of generics can't be used as upper bounds,
       at least not without throwing type checker warnings (thanks to PEP695).
       So I'm not allowed to have

       ```
       class BatchedDataset[U, R, D: Domain[U, R]]:
           ...
       ```

       which could bring appropriately dynamic typing for ``Domain``s, but is
       not a sufficiently concrete upper bound.

    So: we settle for a class-level type declaration, which despite not being
    technically appropriately scoped, isn't harming anything and satisfies
    ``ty`` type checks downstream (e.g., when we access ``DiskDomain.root``).
    """

    # narrowed (class-level) type for the domain; see docstring above
    domain: DiskDomain
|
||||
|
||||
|
||||
class PackedDataset(DiskDataset):
    """
    Packed dataset.

    Currently out of commission - not compatible with latest dataset
    definitions. Will require a zipped disk domain

    Requires a specific dataset storage structure on the root data path:

        <data-path>/data/*-i<batch-num>-b<batch-size>
        <data-path>/meta/*-i<batch-num>-b<batch-size>

    That is, all data are compacted into core data (`data/`) and metadata
    (`meta/`) subdirectories. Compatible out-of-the-box with datasets written
    with a `Packer`.
    """

    def _get_uri_groups(self) -> list[tuple[Path, ...]]:
        """Pair each packed batch's data file with its metadata file."""
        data_root = Path(self.domain.root, "data")
        meta_root = Path(self.domain.root, "meta")

        # sort both listings: ``Path.iterdir()`` yields entries in arbitrary,
        # filesystem-dependent order, while the strict zip below pairs files
        # purely positionally — an unsorted listing could silently pair a
        # data file with the wrong metadata file
        data_file_paths = sorted(data_root.iterdir())
        meta_file_paths = sorted(meta_root.iterdir())

        return list(zip(data_file_paths, meta_file_paths, strict=True))

    def _process_batch_data(
        self,
        batch_data: tuple[bytes, ...],
        batch_index: int,
    ) -> list[tuple[bytes, ...]]:
        """Unpack a (data blob, meta blob) pair into a list of item tuples."""
        data_bytes, meta_bytes = batch_data

        meta_batch = self._unpack_meta(meta_bytes)
        data_batch = self._unpack_data(data_bytes, meta_batch)

        # zip up batch partial batch items into a single batch iterable
        # composed of item tuples
        batch_items = [
            (*ba, *bm)  # pyre-ignore[60]
            for ba, bm in zip(data_batch, meta_batch, strict=True)
        ]

        # apply transform to batch items if provided
        # NOTE(review): ``load_transform`` isn't defined anywhere in the
        # visible hierarchy; presumably set by the base class — confirm
        if self.load_transform:
            batch_items = list(map(self.load_transform, batch_items))

        return batch_items

    @abstractmethod
    def _unpack_data(
        self,
        batch_data_bytes: bytes,
        batch_meta: list[tuple[Any, ...]],
    ) -> list[tuple[Any, ...]]:
        """
        Load and unpack batch data.

        This method should be the inverse of an affiliated
        `Packer`'s `pack_data_bytes()`:

            <data list> -> pack -> to bytes -> blob
            <data list> <- unpack <- from bytes <- blob

        Returns:
            iterable of (partial) item tuples w/ data content
        """

        raise NotImplementedError

    @abstractmethod
    def _unpack_meta(self, batch_meta_bytes: bytes) -> list[tuple[Any, ...]]:
        """
        Load and unpack batch metadata.

        This method should be the inverse of an affiliated
        `Packer`'s `pack_meta_bytes()`:

            <data list> -> pack -> to bytes -> blob
            <data list> <- unpack <- from bytes <- blob

        Returns:
            iterable of (partial) item tuples w/ meta content
        """

        raise NotImplementedError
|
||||
|
||||
|
||||
class ZippedDataset(DiskDataset):
    """
    Dataset with samples stored in ZIP files.

    This dataset base is primarily used as the type of input dataset for a
    `Packer` object. This is compatible with most raw dataset structures,
    reading down arbitrarily packaged ZIP files and "re-batching" during
    access.
    """

    # column names for the item tuples produced below
    item_header: tuple[str, ...] = ("bytes",)

    def _get_uri_groups(self) -> list[tuple[str, ...]]:
        """Collect each ZIP file under the domain root as its own batch."""
        # route through the domain for the root directory: the original read
        # ``self.data_path``, which nothing in the visible hierarchy defines
        # (sibling ``PackedDataset`` uses ``self.domain.root``). Sorted for a
        # deterministic batch order across filesystems.
        zip_file_paths: list[str] = sorted(
            str(path)
            for path in Path(self.domain.root).iterdir()
            if path.suffix == ".zip"
        )

        # will just zip a single list yielding 1-tuples (to match type sig)
        return list(zip(zip_file_paths))

    def _process_batch_data(
        self,
        batch_data: tuple[bytes, ...],
        batch_index: int,
    ) -> list[tuple[bytes, ...]]:
        """Expand one ZIP blob into (member bytes, member name) item tuples,
        filtered by the domain's extension whitelist."""
        items = []
        batch_zip_file = ZipFile(BytesIO(batch_data[0]))

        # the extension whitelist lives on the domain (``None`` means allow
        # all): the original read ``self.extensions``, which isn't defined on
        # the dataset, and an unguarded ``suffix not in None`` would raise
        # for the domain's default configuration
        extensions = self.domain.extensions

        for zname in batch_zip_file.namelist():
            if extensions is not None and Path(zname).suffix not in extensions:
                continue

            with batch_zip_file.open(zname, "r") as zfile:
                items.append((zfile.read(), zname))

        return items
|
||||
|
||||
210
trainlib/datasets/memory.py
Normal file
210
trainlib/datasets/memory.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""
|
||||
Leaving the following here in case we return for some specifics in the future.
|
||||
At an earlier stage, we had a few specifically typed datasets/domains for
|
||||
in-memory structures, but these are almost entirely redundant with the generic
|
||||
``SequenceDomain`` definition: the general retrieval and iteration behaviors
|
||||
there are fairly universal and can be type-tailored without new class
|
||||
definitions (unless you really want a new hierarchy).
|
||||
|
||||
The following were design notes on a dict/named tuple list-based dataset.
|
||||
|
||||
Dataset from list of (dict) records.
|
||||
|
||||
This is designed such that a "batch" is just a single record in the base
|
||||
list. Each URI group is a singleton index tuple, which grabs its single
|
||||
corresponding record in the domain when read.
|
||||
|
||||
One could alternatively have the entire record list be a single batch, and
|
||||
have a single URI group with N null references. The null reference would
|
||||
need to be the sole URI for the domain definition, and ``.read(null)``
|
||||
would always return the entire list. Everything will then be handled as
|
||||
expected: in ``get_item()``,
|
||||
|
||||
- ``._get_batch_for_item(item_index)`` maps to the singleton batch
|
||||
- ``.get_batch(batch_index)[index_in_batch]`` simply indexes directly in
|
||||
the record list
|
||||
|
||||
This is pretty unnatural, since now a batch as returned by
|
||||
``.read_resources()`` will be N references to the entire record list. It
|
||||
nevertheless sidesteps full list reconstruction and allows the propagation
|
||||
of direct indexing, so it has that in its corner.
|
||||
|
||||
Another viable approach (least preferred): have a single batch, but where
|
||||
URIs map to individual row indices. In some respects, this is the most
|
||||
intuitive interpretation of "batch" and how we'd map to items, but it's the
|
||||
*least efficient* because the batch logic of ``BatchedDataset`` will
|
||||
*reconstruct the entire record list*. In ``.read_resources()``, we have
|
||||
|
||||
```
|
||||
return tuple(self.domain.read(uri) for uri in uri_group)
|
||||
```
|
||||
|
||||
So, despite being a natural interpretation, it pulls apart the domain "too
|
||||
well" and rebuilds it in-house. This makes sense when resources are
|
||||
external (e.g., files on disk, data over network), but when already
|
||||
part of the Python process, that model is wasteful.
|
||||
|
||||
Note: inheriting datasets will need to implement an appropriate
|
||||
``_item_header``. This will be trivial for consistent datasets, since it'll
|
||||
just be the keys of any of the dict records.
|
||||
|
||||
Left to define:
|
||||
|
||||
+ ``_process_item_data()``: map T to final tuple
|
||||
|
||||
|
||||
|
||||
class RecordDataset[T: NamedTuple](HomogenousDataset[int, T, T]):
|
||||
|
||||
domain: TupleListDomain[T]
|
||||
|
||||
def _process_batch_data(
|
||||
self,
|
||||
batch_data: T,
|
||||
batch_index: int,
|
||||
) -> list[T]:
|
||||
Produce collection of item tuples.
|
||||
|
||||
Note: no interaction with ``item_tuple`` is needed given batches are
|
||||
already singular ``item_tuple`` shaped items.
|
||||
|
||||
return [batch_data]
|
||||
"""
|
||||
|
||||
from typing import Unpack
|
||||
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from trainlib.domain import SequenceDomain
|
||||
from trainlib.dataset import TupleDataset, DatasetKwargs
|
||||
|
||||
|
||||
class SlidingWindowDataset[T: Tensor](TupleDataset[T]):
    """
    Dataset producing sliding lookback windows (and lookahead labels) over
    channel-by-time tensors drawn from a sequence domain.

    The first ``num_windows`` tensors of each batch tuple are turned into
    window stacks; the remaining tensors become lookahead-aligned label
    stacks. See ``_process_batch_data`` for the precise geometry.
    """

    def __init__(
        self,
        domain: SequenceDomain[tuple[T, ...]],
        lookback: int,
        offset: int = 0,
        lookahead: int = 1,
        num_windows: int = 1,
        **kwargs: Unpack[DatasetKwargs],
    ) -> None:
        """
        Parameters:
            domain: sequence domain yielding tuples of (C, T) tensors
            lookback: window length; strictly positive
            offset: right-shift of the first window (see method docstring);
                values >= ``lookback`` are clamped to ``lookback - 1``
            lookahead: label offset from the first index
            num_windows: number of leading tuple entries to window; the
                remaining entries are treated as label tensors
        """

        self.lookback = lookback
        self.offset = offset
        self.lookahead = lookahead
        self.num_windows = num_windows

        # hyperparameters must be set before the parent init: the
        # ``HomogenousDataset`` constructor probes batches, which runs
        # ``_process_batch_data`` and reads these attributes
        super().__init__(domain, **kwargs)

    def _process_batch_data(
        self,
        batch_data: tuple[T, ...],
        batch_index: int,
    ) -> list[tuple[T, ...]]:
        """
        Backward pads first sequence over (lookback-1) length, and steps the
        remaining items forward by the lookahead.

        Batch data:

            (Tensor[C1, T], ..., Tensor[CN, T])

        + lookback determines window size; pad left to create lookback size
          with the first element at the right:

              |-lookback-|
              [0 ... 0 T0]

          `lookback` is strictly positive, unbounded.

        + offset shifts the first such window forward. `0 <= offset < L`;
          think of it as the number of additional non-padded items we
          slide into the window from the right. At its largest value of `L-1`,
          we'll have L-sized windows with `T0` as the leftmost element.

              offset=2
              [0 ... T0 T1 T2]
              |---lookback---|

          In effect, the index of the rightmost item of the first window
          will be equal to the value of `offset`. There are `T - offset`
          total possible windows over a given sequence (regardless of
          lookback).

        + lookahead determines the offset of the "label" slices from the
          first index, regardless of any value of `offset`.

              lookahead=3
              [0 ... 0 T0] T1 T2 [T3]
              |-lookback-|

              0 [0 .. T0 T1] T2 T3 [T4]
                |-lookback-|

          There are `T - lookahead` allowed slices, assuming the lookahead
          exceeds the offset.

        To get windows starting with the first index at the left: we first set
        out window size (call it L), determined by `lookback`. Then the
        rightmost index we want will be `L-1`, which determines our `offset`
        setting.

            lookahead=L, offset=L-1
            [ T_0 ... T_{L-1} ]

        To get a one-step lookahead in front of that rightmost item, the
        `lookahead` can be set to the index of the first label we want:

            lookback=L, offset=L-1 lookahead=L
            [ T_0 ... T_{L-1} ] [ T_L ]

        """

        if self.pre_transform is not None:
            batch_data = self.pre_transform(batch_data)

        lb = self.lookback
        # clamp: the documented domain for the offset is 0 <= offset < lb
        off = min(self.offset, lb-1)
        la = self.lookahead

        ws = []
        for t in batch_data[:self.num_windows]:
            # for window sized `lb`, we pad with `lb-1` zeros. We then take off
            # the amount of our offset, which in the extreme cases does no
            # padding.
            xip = F.pad(t, ((lb-1) - off, 0))

            # extract sliding windows over the padded tensor
            # unfold(-1, lb, 1) slides over the last dim, 1 step at a time, for
            # `lb`-sized windows. We turn (C_i, pad+T) shape into
            # (C_i, T-offset, lb), giving `T-offset` total `lb` windows.
            wi = xip.unfold(-1, lb, 1)

            # (C_i, T-offset, lb) -> (T-offset, C_i, lb)
            wi = wi.permute(1, 0, 2)

            # if lookahead exceeds offset, there are some windows for which we
            # won't be able to assign a lookahead label. Cut those off here,
            # leaving T-lookahead windows to match the label count below
            if la - off > 0:
                wi = wi[:-(la-off)]

            ws.append(wi)

        ys = []
        for t in batch_data[self.num_windows:]:
            # tensors (C_i, T) shaped, align with lookahead, giving a
            # (C_i, T-lookahead)
            y = t[:, la:]

            # (C_i, T-lookahead) -> (T-lookahead, C_i)
            y = y.permute(1, 0)

            # cut off any elements if offset exceeds lookahead
            if off - la > 0:
                y = y[:-(off-la)]

            ys.append(y)

        # both paths leave T - max(offset, lookahead) leading entries per
        # tensor, so the strict zip pairs each window stack with its labels
        return list(zip(*ws, *ys, strict=True))
|
||||
|
||||
85
trainlib/domain.py
Normal file
85
trainlib/domain.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""
|
||||
Defines a knowledge domain. Wraps a Dataset / Simulator / Knowledge
|
||||
|
||||
Downstream exploration might include
|
||||
|
||||
- Calibrating Simulator / Knowledge with a Dataset
|
||||
- Amending Dataset with Simulator / Knowledge
|
||||
- Positioning Knowledge within Simulator context
|
||||
* Where to replace Simulator subsystem with Knowledge?
|
||||
|
||||
Other variations:
|
||||
|
||||
- Multi-fidelity simulators
|
||||
- Multi-scale models
|
||||
- Multi-system
|
||||
- Incomplete knowledge / divergence among sources
|
||||
|
||||
Questions:
|
||||
|
||||
- Should Simulator / Knowledge be unified as one (e.g., "Expert")
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping, Iterator, Sequence
|
||||
|
||||
|
||||
class Domain[U, R](Mapping[U, R]):
|
||||
"""
|
||||
Domain base class, generic to a URI type ``U`` and resource type ``R``.
|
||||
|
||||
Domains are just Mappings where the iterator behavior is specifically typed
|
||||
to range over keys (URIs). Defining a specific class here gives us a base
|
||||
for a nominal hierarchy, but functionally the Mapping core (sized iterables
|
||||
with accessors).
|
||||
"""
|
||||
|
||||
def __call__(self, uri: U) -> R:
|
||||
"""
|
||||
Get the resource for a given URI (call-based alias).
|
||||
"""
|
||||
|
||||
return self[uri]
|
||||
|
||||
def __getitem__(self, uri: U) -> R:
|
||||
"""
|
||||
Get the resource for a given URI.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def __iter__(self) -> Iterator[U]:
|
||||
"""
|
||||
Provide an iterator over domain URIs.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Measure the size the domain.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SequenceDomain[R](Domain[int, R]):
|
||||
"""
|
||||
Trivial domain implementation for wrapping sequences that can be seen as
|
||||
Mappings with 0-indexed keys.
|
||||
|
||||
Why define this? Domains provide iterators over their *keys*, sequences
|
||||
often iterate over *values*.
|
||||
"""
|
||||
|
||||
def __init__(self, sequence: Sequence[R]) -> None:
|
||||
self.sequence = sequence
|
||||
|
||||
def __getitem__(self, uri: int) -> R:
|
||||
return self.sequence[uri]
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
return iter(range(len(self.sequence)))
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.sequence)
|
||||
0
trainlib/domains/__init__.py
Normal file
0
trainlib/domains/__init__.py
Normal file
37
trainlib/domains/disk.py
Normal file
37
trainlib/domains/disk.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from pathlib import Path
|
||||
from collections.abc import Iterator
|
||||
|
||||
from mema.domain import Domain
|
||||
|
||||
|
||||
class DiskDomain(Domain[Path, bytes]):
    """
    Domain over files on disk: URIs are paths, resources are raw bytes.
    """

    def __init__(
        self,
        root: Path,
        extensions: list[str] | None = None,
    ) -> None:
        """
        Parameters:
            root: directory whose files make up the domain
            extensions: list of file extensions to filter for when
                determining data file paths to read. This is a whitelist,
                except when left as ``None`` (the default), in which case all
                extensions are allowed.
        """

        self.root = root
        self.extensions = extensions

    def __getitem__(self, uri: Path) -> bytes:
        # note: the URI itself is the full path; ``root`` only scopes
        # iteration below
        return uri.read_bytes()

    def __iter__(self) -> Iterator[Path]:
        allowed = self.extensions
        for path in self.root.iterdir():
            if path.is_file() and (allowed is None or path.suffix in allowed):
                yield path

    def __len__(self) -> int:
        # count by exhausting the (filtered) iterator
        return sum(1 for _ in self)
|
||||
58
trainlib/domains/functional.py
Normal file
58
trainlib/domains/functional.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from collections.abc import Callable, Iterator, Sequence
|
||||
|
||||
from trainlib.domain import Domain
|
||||
|
||||
|
||||
class SimulatorDomain[P, R](Domain[int, R]):
|
||||
"""
|
||||
Base simulator domain, generic to arbitrary callables.
|
||||
|
||||
Note: we don't store simulation results here; that's left to a downstream
|
||||
object, like a `BatchedDataset`, to cache if needed. We also don't subclass
|
||||
`SequenceDataset` because the item getter type doesn't align: we accept an
|
||||
`int` in the parameter list, but don't return the items directly from that
|
||||
collection (we transform them first).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
simulator: Callable[[P], R],
|
||||
parameters: Sequence[P],
|
||||
) -> None:
|
||||
self.simulator = simulator
|
||||
self.parameters = parameters
|
||||
|
||||
def __getitem__(self, uri: int) -> R:
|
||||
return self.simulator(self.parameters[uri])
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
return iter(range(len(self.parameters)))
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.parameters)
|
||||
|
||||
|
||||
class SimulatorPredictiveDomain[P, R](Domain[int, tuple[R, ...]]):
|
||||
def __init__(
|
||||
self,
|
||||
simulator: Callable[[P], R],
|
||||
parameters: Sequence[P],
|
||||
predictives: Sequence[Callable[[R], R]],
|
||||
) -> None:
|
||||
self.simulator = simulator
|
||||
self.parameters = parameters
|
||||
self.predictives = predictives
|
||||
|
||||
def __getitem__(self, uri: int) -> tuple[R, ...]:
|
||||
sample = self.simulator(self.parameters[uri])
|
||||
|
||||
return (
|
||||
sample,
|
||||
*(p(sample) for p in self.predictives)
|
||||
)
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
return iter(range(len(self.parameters)))
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.parameters)
|
||||
195
trainlib/estimator.py
Normal file
195
trainlib/estimator.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""
|
||||
Development note
|
||||
|
||||
I'd rather lay out bare args and kwargs in the estimator methods, but the
|
||||
allowed variance in subclasses makes this difficult. To ease the pain of typing
|
||||
coordination with datasets, `loss()` / `forward()` / `metrics()` specify their
|
||||
arguments through a TypedDict. This makes those signatures more portable, and
|
||||
can bound type variables in other places.
|
||||
|
||||
In theory, even the single tensor input base currently in place could be
|
||||
relaxed given the allowed variability in dataset / dataloader outputs. The
|
||||
default collate function in the PyTorch `DataLoader` source leaves types like
|
||||
strings and bytes unchanged, so not all kinds of data are batched into a
|
||||
`Sequence`, let alone a tensor. Nevertheless, estimators are `nn.Module`
|
||||
derivatives, so it's at minimum a safe assumption we'll need at least one
|
||||
tensor (or *should* have one).
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Unpack, TypedDict
|
||||
from collections.abc import Generator
|
||||
|
||||
from torch import nn, Tensor
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from trainlib.util.type import OptimizerKwargs
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EstimatorKwargs(TypedDict):
    """Minimal keyword-argument schema shared by all estimators."""

    inputs: Tensor  # batched input tensor; always required
|
||||
|
||||
|
||||
class Estimator[Kw: EstimatorKwargs](nn.Module):
    """
    Estimator base class.

    All methods that raise ``NotImplementedError`` are directly invoked in the
    ``Trainer.train(...)`` loop.

    Note the flexibility afforded to the signatures of `forward()`, `loss()`,
    and `metrics()` methods, which should generally have identical sets of
    arguments in inheriting classes. The base class is generic to a type
    `Kw`, which should be a `TypedDict` for inheriting classes (despite not
    being enforceable as an upper bound) that reflects the keyword argument
    structure for these methods.

    For instance, in a sequence prediction model with labels and masking, you
    might have something like:

    .. code-block:: python

        class PredictorKwargs(TypedDict, total=False):
            labels: list[Tensor]
            lengths: Tensor
            mask: Tensor

        class SequencePredictor(Estimator[PredictorKwargs]):
            def forward(
                self,
                input: Tensor,
                **kwargs: Unpack[PredictorKwargs],
            ) -> tuple[Tensor, ...]:
                ...

            def loss(
                self,
                input: Tensor,
                **kwargs: Unpack[PredictorKwargs],
            ) -> Generator[Tensor]:
                ...

            def metrics(
                self,
                input: Tensor,
                **kwargs: Unpack[PredictorKwargs],
            ) -> dict[str, float]:
                ...

    While `loss` and `metrics` should leverage the full set of keyword
    arguments, `forward` may not (e.g., in the example above, it shouldn't use
    `labels`).

    Subclasses of `SequencePredictor` should then be generic over subtypes of
    `PredictorKwargs`.
    """

    def forward(
        self,
        **kwargs: Unpack[Kw],
    ) -> tuple[Tensor, ...]:
        """
        Run the model on the given inputs; returns a tuple of tensors.
        """

        raise NotImplementedError

    def loss(
        self,
        **kwargs: Unpack[Kw],
    ) -> Generator[Tensor]:
        """
        Compute model loss for the given input.

        Note that the loss is implemented as a generator to support
        multi-objective estimator setups. That is, losses can be yielded in
        sequence, allowing for a training loop to propagate model parameter
        updates before the next loss function is calculated. For instance, in a
        GAN-like setup, one might first emit the D-loss, update D parameters,
        then compute the G-loss (*depending* on the updated D parameters). Such
        a scheme is not otherwise possible without bringing the intermediate
        parameter update "in house" (breaking a separation of duties with the
        train loop).
        """

        raise NotImplementedError

    def metrics(
        self,
        **kwargs: Unpack[Kw],
    ) -> dict[str, float]:
        """
        Compute metrics for the given input, keyed by metric name.
        """

        raise NotImplementedError

    def optimizers(
        self,
        **kwargs: Unpack[OptimizerKwargs],
    ) -> tuple[Optimizer, ...]:
        """
        Get optimizers for the estimator to use in training loops.

        Example providing a singular Adam-based optimizer:

        .. code-block:: python

            def optimizers(...):
                optimizer = torch.optim.AdamW(
                    self.parameters(),
                    lr=1e-3,
                    eps=1e-8,
                )
                return (optimizer,)
        """

        raise NotImplementedError

    def epoch_step(self) -> None:
        """
        Step epoch-dependent model state.

        This method should not include optimization of primary model
        parameters; that should be left to external optimizers. Instead, this
        method should step forward things like internal hyperparameter
        schedules in accordance with the expected call rate (e.g., every
        epoch).
        """

        raise NotImplementedError

    def epoch_write(
        self,
        writer: SummaryWriter,
        step: int | None = None,
        val: bool = False,
        **kwargs: Unpack[Kw],
    ) -> None:
        """
        Write epoch-dependent Tensorboard items.

        If implemented, this should supplement that which is provided in
        ``metrics()``. Tensors provided via ``kwargs`` should include raw
        training/validation data; examples include writing raw embeddings,
        canonical visualizations of the samples (e.g., previewing images), etc.

        Parameters:
            writer: tensorboard writer instance
            step: current step in the optimization loop
            val: whether the provided batch is a validation sample
            kwargs: batch of tensors from the current epoch
        """

        raise NotImplementedError

    def log_arch(self) -> None:
        """
        Log Estimator architecture details.
        """

        logger.info(f"> Estimator :: {self.__class__.__name__}")

        # count only trainable parameters
        num_params = sum(
            p.numel() for p in self.parameters() if p.requires_grad
        )
        logger.info(f"| > # model parameters: {num_params}")
|
||||
0
trainlib/estimators/__init__.py
Normal file
0
trainlib/estimators/__init__.py
Normal file
491
trainlib/estimators/rnn.py
Normal file
491
trainlib/estimators/rnn.py
Normal file
@@ -0,0 +1,491 @@
|
||||
import logging
from typing import Unpack, NotRequired
from collections.abc import Generator

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter

# NOTE: sibling modules in this repo live under ``trainlib`` (this file is
# trainlib/estimators/rnn.py and trainlib/trainer.py imports
# ``trainlib.estimator``); the previous ``mema`` package name appears to be a
# stale leftover from a rename and would fail to import here.
from trainlib.estimator import Estimator, EstimatorKwargs
from trainlib.util.type import OptimizerKwargs
from trainlib.util.module import get_grad_norm
from trainlib.estimators.tdnn import TDNNLayer

logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RNNKwargs(EstimatorKwargs):
    """Keyword-argument schema for the RNN estimators in this module."""

    inputs: Tensor  # batch shaped (B, C, T)
    labels: NotRequired[Tensor]  # regression targets; required by loss()/metrics()
|
||||
|
||||
|
||||
class LSTM[K: RNNKwargs](Estimator[K]):
    """
    Base RNN architecture.

    Pipeline: linear input projection -> tanh -> additive noise clamp ->
    multi-layer LSTM -> linear output head; only the final time step's
    output is returned.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: int = 64,
        num_layers: int = 4,
        bidirectional: bool = False,
        verbose: bool = True,
    ) -> None:
        """
        Parameters:
            input_dim: dimensionality of the input
            output_dim: dimensionality of the output
            hidden_dim: dimensionality of each LSTM layer output
            num_layers: number of LSTM layers to use
            bidirectional: whether the LSTM runs in both directions
            verbose: log architecture details at construction time
        """

        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        self.dense_in = nn.Linear(input_dim, hidden_dim)
        self.lstm = nn.LSTM(
            hidden_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
        )

        # bidirectional LSTMs concatenate forward/backward hidden states
        lstm_out_dim = hidden_dim * (2 if bidirectional else 1)
        self.dense_z = nn.Linear(lstm_out_dim, output_dim)

        # weight initialization for LSTM layers
        def init_weights(m: nn.Module) -> None:
            if isinstance(m, nn.LSTM):
                for name, p in m.named_parameters():
                    if "weight_ih" in name:
                        nn.init.xavier_uniform_(p)
                    elif "weight_hh" in name:
                        nn.init.orthogonal_(p)
                    elif "bias" in name:
                        nn.init.zeros_(p)

        self.apply(init_weights)

        if verbose:
            self.log_arch()

    def _clamp_rand(self, x: Tensor) -> Tensor:
        # add small uniform noise (~1/127 range) then clamp to [-1, 1];
        # presumably a quantization-aware regularizer -- TODO confirm intent
        return torch.clamp(
            x + (1.0 / 127.0) * (torch.rand_like(x) - 0.5),
            min=-1.0,
            max=1.0,
        )

    def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]:
        """
        Returns the final-step output and the LSTM hidden state.
        """
        inputs = kwargs["inputs"]

        # data shaped (B, C, T); map to (B, T, C)
        x = inputs.permute(0, 2, 1)
        x = torch.tanh(self.dense_in(x))
        x = self._clamp_rand(x)
        x, hidden = self.lstm(x)
        z = self.dense_z(x)

        # keep only the last time step
        return z[:, -1, :], hidden

    def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]:
        """
        Yield a single MSE loss; requires ``labels`` in kwargs.
        """
        predictions = self(**kwargs)[0]
        labels = kwargs["labels"]

        yield F.mse_loss(predictions, labels)

    def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]:
        """Compute loss and gradient-norm metrics without tracking grads."""
        with torch.no_grad():
            loss = next(self.loss(**kwargs)).item()

        return {
            "loss": loss,
            "grad_norm": get_grad_norm(self)
        }

    def optimizers(
        self,
        **kwargs: Unpack[OptimizerKwargs],
    ) -> tuple[Optimizer, ...]:
        """
        Build a single AdamW optimizer; caller kwargs override the defaults
        (lr=1e-3, eps=1e-8).
        """

        default_kwargs: OptimizerKwargs = {
            "lr": 1e-3,
            "eps": 1e-8,
        }
        opt_kwargs = {**default_kwargs, **kwargs}

        optimizer = torch.optim.AdamW(
            self.parameters(),
            **opt_kwargs,
        )

        return (optimizer,)

    def epoch_step(self) -> None:
        # no epoch-dependent state to advance
        return None

    def epoch_write(
        self,
        writer: SummaryWriter,
        step: int | None = None,
        val: bool = False,
        **kwargs: Unpack[K],
    ) -> None:
        # no extra tensorboard items beyond metrics()
        return None

    def log_arch(self) -> None:
        """Log architecture hyperparameters."""
        super().log_arch()

        logger.info(f"| > {self.input_dim=}")
        logger.info(f"| > {self.hidden_dim=}")
        logger.info(f"| > {self.num_layers=}")
        logger.info(f"| > {self.output_dim=}")
|
||||
|
||||
|
||||
class MultiheadLSTMKwargs(EstimatorKwargs):
    """Keyword-argument schema for ``MultiheadLSTM``."""

    inputs: Tensor  # batch shaped (B, C, T)
    labels: NotRequired[Tensor]  # primary-head targets; required by loss()
    auxiliary: NotRequired[Tensor]  # concatenated targets for auxiliary heads
|
||||
|
||||
|
||||
class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]):
    """
    LSTM estimator with a primary output head plus auxiliary heads.

    Auxiliary head outputs are concatenated along the channel axis and, when
    ``auxiliary`` targets are provided, trained jointly with the primary head.

    NOTE(review): ``forward`` calls ``torch.cat`` over the head outputs, so an
    empty ``head_dims`` (the default) raises at runtime -- confirm whether a
    heads-free configuration should be supported.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: int = 64,
        num_layers: int = 4,
        bidirectional: bool = False,
        head_dims: list[int] | None = None,
        verbose: bool = True,
    ) -> None:
        """
        Parameters:
            input_dim: dimensionality of the input
            output_dim: dimensionality of the primary output head
            hidden_dim: dimensionality of each LSTM layer output
            num_layers: number of LSTM layers to use
            bidirectional: whether the LSTM runs in both directions
            head_dims: output dimensionality of each auxiliary head
            verbose: log architecture details at construction time
        """

        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.head_dims = head_dims if head_dims is not None else []

        self.dense_in = nn.Linear(input_dim, hidden_dim)
        self.lstm = nn.LSTM(
            hidden_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
        )

        # bidirectional LSTMs concatenate forward/backward hidden states
        lstm_out_dim = hidden_dim * (2 if bidirectional else 1)
        self.dense_z_out = nn.Linear(lstm_out_dim, output_dim)
        self.dense_z_heads = nn.ModuleList([
            nn.Linear(lstm_out_dim, head_dim)
            for head_dim in self.head_dims
        ])

        # weight initialization for LSTM layers
        def init_weights(m: nn.Module) -> None:
            if isinstance(m, nn.LSTM):
                for name, p in m.named_parameters():
                    if "weight_ih" in name:
                        nn.init.xavier_uniform_(p)
                    elif "weight_hh" in name:
                        nn.init.orthogonal_(p)
                    elif "bias" in name:
                        nn.init.zeros_(p)

        self.apply(init_weights)

        if verbose:
            self.log_arch()

    def _clamp_rand(self, x: Tensor) -> Tensor:
        # add small uniform noise (~1/127 range) then clamp to [-1, 1];
        # presumably a quantization-aware regularizer -- TODO confirm intent
        return torch.clamp(
            x + (1.0 / 127.0) * (torch.rand_like(x) - 0.5),
            min=-1.0,
            max=1.0,
        )

    def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]:
        """
        Returns final-step outputs of the primary head and the concatenated
        auxiliary heads, respectively.
        """
        inputs = kwargs["inputs"]

        # data shaped (B, C, T); map to (B, T, C)
        x = inputs.permute(0, 2, 1)
        x = torch.tanh(self.dense_in(x))
        x = self._clamp_rand(x)
        x, hidden = self.lstm(x)

        z = self.dense_z_out(x)
        zs = torch.cat([head(x) for head in self.dense_z_heads], dim=-1)

        # keep only the last time step of each
        return z[:, -1, :], zs[:, -1, :]

    def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]:
        """
        Yield a single MSE loss; the auxiliary-head term is included only
        when ``auxiliary`` targets are provided.
        """
        pred, pred_aux = self(**kwargs)
        labels = kwargs["labels"]
        aux_labels = kwargs.get("auxiliary")

        if aux_labels is None:
            yield F.mse_loss(pred, labels)
        else:
            yield F.mse_loss(pred, labels) + F.mse_loss(pred_aux, aux_labels)

    def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]:
        """Compute loss and gradient-norm metrics without tracking grads."""
        with torch.no_grad():
            loss = next(self.loss(**kwargs)).item()

        return {
            "loss": loss,
            "grad_norm": get_grad_norm(self)
        }

    def optimizers(
        self,
        **kwargs: Unpack[OptimizerKwargs],
    ) -> tuple[Optimizer, ...]:
        """
        Build a single AdamW optimizer; caller kwargs override the defaults
        (lr=1e-3, eps=1e-8).
        """

        default_kwargs: OptimizerKwargs = {
            "lr": 1e-3,
            "eps": 1e-8,
        }
        opt_kwargs = {**default_kwargs, **kwargs}

        optimizer = torch.optim.AdamW(
            self.parameters(),
            **opt_kwargs,
        )

        return (optimizer,)

    def epoch_step(self) -> None:
        # no epoch-dependent state to advance
        return None

    def epoch_write(
        self,
        writer: SummaryWriter,
        step: int | None = None,
        val: bool = False,
        **kwargs: Unpack[K],
    ) -> None:
        # no extra tensorboard items beyond metrics()
        return None

    def log_arch(self) -> None:
        """Log architecture hyperparameters."""
        super().log_arch()

        logger.info(f"| > {self.input_dim=}")
        logger.info(f"| > {self.hidden_dim=}")
        logger.info(f"| > {self.num_layers=}")
        logger.info(f"| > {self.output_dim=}")
|
||||
|
||||
|
||||
class ConvRNN[K: RNNKwargs](Estimator[K]):
    """
    Base recurrent convolutional architecture.

    Computes latents, initial states, and rate estimates from features and
    lambda parameter. Alternates GRU and TDNN (conv) layers with dense skip
    connections: each layer consumes the concatenation of all previous layer
    outputs.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        temporal_dim: int,
        gru_dim: int = 64,
        conv_dim: int = 96,
        num_layers: int = 4,
        conv_kernel_sizes: list[int] | None = None,
        conv_dilations: list[int] | None = None,
        verbose: bool = True,
    ) -> None:
        """
        Parameters:
            input_dim: dimensionality of the input
            output_dim: dimensionality of the output
            temporal_dim: kernel size of the final (unpadded) output
                convolution; inputs of exactly this length collapse the
                time axis to a single step
            gru_dim: dimensionality of each GRU layer output
            conv_dim: dimensionality of each conv layer output
            num_layers: number of gru-conv layer pairs to use
            conv_kernel_sizes: kernel sizes for conv layers
            conv_dilations: dilation settings for conv layers
            verbose: log architecture details at construction time
        """

        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        self.gru_dim = gru_dim
        self.conv_dim = conv_dim
        self.num_layers = num_layers
        # accumulated from the conv layers constructed below
        self.receptive_field = 0

        self.conv_kernel_sizes: list[int]
        if conv_kernel_sizes is None:
            self.conv_kernel_sizes = [4] * num_layers
        else:
            self.conv_kernel_sizes = conv_kernel_sizes

        self.conv_dilations: list[int]
        if conv_dilations is None:
            # first layer undilated, the rest dilated by 2
            self.conv_dilations = [1] + [2] * (num_layers - 1)
        else:
            self.conv_dilations = conv_dilations

        self._gru_layers: nn.ModuleList = nn.ModuleList()
        self._conv_layers: nn.ModuleList = nn.ModuleList()

        # dense skip connections: layer input width grows by each layer's
        # output width as outputs are concatenated along channels
        layer_in_dim = gru_dim
        for i in range(self.num_layers):
            gru_layer = nn.GRU(layer_in_dim, gru_dim, batch_first=True)
            self._gru_layers.append(gru_layer)
            layer_in_dim += gru_dim

            tdnn_layer = TDNNLayer(
                layer_in_dim,
                conv_dim,
                kernel_size=self.conv_kernel_sizes[i],
                dilation=self.conv_dilations[i],
            )
            self.receptive_field += tdnn_layer.receptive_field

            self._conv_layers.append(tdnn_layer)
            layer_in_dim += conv_dim

        # pointwise (kernel_size=1) embedding of raw input channels;
        # operates on (B, C, T)
        self.dense_in = TDNNLayer(
            self.input_dim,
            gru_dim,
            kernel_size=1,
            pad=False
        )

        # final unpadded convolution spanning temporal_dim steps; operates
        # on (B, C, T), mixing channels at each step
        self.dense_z = TDNNLayer(
            layer_in_dim,
            self.output_dim,
            kernel_size=temporal_dim,
            pad=False,
        )

        # weight initialization for GRU layers
        def init_weights(module: nn.Module) -> None:
            if isinstance(module, nn.GRU):
                for p in module.named_parameters():
                    if p[0].startswith("weight_hh_"):
                        nn.init.orthogonal_(p[1])

        self.apply(init_weights)

        if verbose:
            self.log_arch()

    def _clamp_rand(self, x: Tensor) -> Tensor:
        # add small uniform noise (~1/127 range) then clamp to [-1, 1];
        # presumably a quantization-aware regularizer -- TODO confirm intent
        return torch.clamp(
            x + (1.0 / 127.0) * (torch.rand_like(x) - 0.5),
            min=-1.0,
            max=1.0,
        )

    def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]:
        """
        Returns a single-element tuple with the output of the final
        convolution, shaped (B, C, T') with T' reduced by the unpadded
        temporal_dim kernel.
        """
        inputs = kwargs["inputs"]

        # embedding shaped (B, C, T)
        x = self._clamp_rand(torch.tanh(self.dense_in(inputs)))

        # prepare shape (B, T, C) -- for GRU
        x = x.transpose(-2, -1)

        for gru, conv in zip(self._gru_layers, self._conv_layers, strict=True):
            xg = self._clamp_rand(gru(x)[0])
            x = torch.cat([x, xg], -1)

            # conv expects channels-first; transpose in and back out
            xc = self._clamp_rand(conv(x.transpose(-2, -1)))
            xc = xc.transpose(-2, -1)
            x = torch.cat([x, xc], -1)

        # map back to (B, C, T) for the output convolution
        x = x.transpose(-2, -1)
        z = self.dense_z(x)

        return (z,)

    def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]:
        """
        Yield a single MSE loss; requires ``labels`` in kwargs.
        """
        predictions = self(**kwargs)[0]
        labels = kwargs["labels"]

        # squeeze last dim; we've mapped T -> 1
        predictions = predictions.squeeze(-1)

        yield F.mse_loss(predictions, labels, reduction="mean")

    def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]:
        """Compute loss and gradient-norm metrics without tracking grads."""
        with torch.no_grad():
            loss = next(self.loss(**kwargs)).item()

        return {
            "loss": loss,
            "grad_norm": get_grad_norm(self)
        }

    def optimizers(
        self,
        **kwargs: Unpack[OptimizerKwargs],
    ) -> tuple[Optimizer, ...]:
        """
        Build a single AdamW optimizer; caller kwargs override the defaults
        (lr=1e-3, eps=1e-8).
        """

        default_kwargs: OptimizerKwargs = {
            "lr": 1e-3,
            "eps": 1e-8,
        }
        opt_kwargs = {**default_kwargs, **kwargs}

        optimizer = torch.optim.AdamW(
            self.parameters(),
            **opt_kwargs,
        )

        return (optimizer,)

    def epoch_step(self) -> None:
        # no epoch-dependent state to advance
        return None

    def epoch_write(
        self,
        writer: SummaryWriter,
        step: int | None = None,
        val: bool = False,
        **kwargs: Unpack[K],
    ) -> None:
        # no extra tensorboard items beyond metrics()
        return None

    def log_arch(self) -> None:
        """Log architecture hyperparameters."""
        super().log_arch()

        logger.info(f"| > {self.input_dim=}")
        logger.info(f"| > {self.gru_dim=}")
        logger.info(f"| > {self.conv_dim=}")
        logger.info(f"| > {self.num_layers=}")
        logger.info(f"| > {self.conv_kernel_sizes=}")
        logger.info(f"| > {self.conv_dilations=}")
        logger.info(f"| > {self.receptive_field=}")
        logger.info(f"| > {self.output_dim=}")
|
||||
114
trainlib/estimators/tdnn.py
Normal file
114
trainlib/estimators/tdnn.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, Tensor
|
||||
from torch.optim import Optimizer
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TDNNLayer(nn.Module):
|
||||
"""
|
||||
Time delay neural network layer.
|
||||
|
||||
Built on torch Conv1D layers, with additional support for automatic
|
||||
padding.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_channels: int,
|
||||
output_channels: int,
|
||||
kernel_size: int = 3,
|
||||
dilation: int = 1,
|
||||
lookahead: int = 0,
|
||||
pad: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Implements a fast TDNN layer via `torch.Conv1d`.
|
||||
|
||||
Note that we're restricted to the kernel shapes producible by `Conv1d`
|
||||
objects, implying (primarily) that the kernel must be symmetric and
|
||||
have equal spacing.
|
||||
|
||||
For example, the symmetric but non-uniform context [-3, -2, 0, +2, +3]
|
||||
cannot be represented, while [-6, -3, 0, 3, 6] can. A few other kernel
|
||||
examples:
|
||||
|
||||
kernel_size=3; dilation=1 -> [-1, 0, 1]
|
||||
kernel_size=3; dilation=3 -> [-3, 0, 3]
|
||||
kernel_size=4; dilation=2 -> [-3, -1, 1, 3]
|
||||
|
||||
By default, the TDNN layer left pads the input to ensure the output has
|
||||
the same sequence length as the original input. For example, with a
|
||||
kernel size of 3 (dilation of 1) and sequence length of T=3, the
|
||||
sequence will be left padded with 2 zeros:
|
||||
|
||||
[0, 0, 1, 1, 1] -> [x1, x2, x3]
|
||||
|
||||
If a lookahead is specified, some number of those left zeros will be
|
||||
moved to the right. If lookahead=1, for instance, indicating 1
|
||||
additional "future frame" of context, padding will look like
|
||||
|
||||
[0, 1, 1, 1, 0] -> [x1, x2, x3]
|
||||
|
||||
The output x_i now sees through time step i+1.
|
||||
|
||||
Parameters:
|
||||
input_channels: number of input channels
|
||||
output_channels: number of channels produced by the temporal
|
||||
convolution
|
||||
kernel_size: total size of the kernel
|
||||
dilation: dilation of receptive field, i.e., the size of the gaps
|
||||
lookahead: number of allowed lookahead frames
|
||||
pad: whether the input should be padded, producing an output with
|
||||
the same sequence length as the input
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.td_conv: nn.Module = weight_norm(
|
||||
nn.Conv1d(
|
||||
input_channels,
|
||||
output_channels,
|
||||
kernel_size=kernel_size,
|
||||
dilation=dilation,
|
||||
)
|
||||
)
|
||||
|
||||
self.pad: bool = pad
|
||||
self.lookahead = lookahead
|
||||
self.receptive_field: int = (kernel_size - 1) * dilation + 1
|
||||
|
||||
assert (
|
||||
self.lookahead < self.receptive_field
|
||||
), "Lookahead cannot exceed receptive field"
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""
|
||||
Dimension definitions:
|
||||
|
||||
- B: batch size
|
||||
- Di: input dimension at single time step (aka input channels)
|
||||
- Do: output dimension at single time step (aka output channels)
|
||||
- T: sequence length
|
||||
|
||||
Parameters:
|
||||
x: input tensor, shaped [B, Di, T] (optionally w/o a batch dim)
|
||||
|
||||
Returns:
|
||||
tensor shaped [B, Do, T-kernel_size]
|
||||
"""
|
||||
|
||||
# pad according to receptive field and lookahead s.t. output
|
||||
# shape [B, *, T]
|
||||
if self.pad:
|
||||
x = F.pad(
|
||||
x,
|
||||
(self.receptive_field - self.lookahead - 1, self.lookahead)
|
||||
)
|
||||
|
||||
return self.td_conv(x)
|
||||
509
trainlib/trainer.py
Normal file
509
trainlib/trainer.py
Normal file
@@ -0,0 +1,509 @@
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from io import BytesIO
|
||||
from copy import deepcopy
|
||||
from typing import Any, Self
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from torch import cuda, Tensor
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from trainlib.dataset import BatchedDataset
|
||||
from trainlib.estimator import Estimator, EstimatorKwargs
|
||||
from trainlib.transform import Transform
|
||||
from trainlib.util.type import (
|
||||
SplitKwargs,
|
||||
LoaderKwargs,
|
||||
BalanceKwargs,
|
||||
)
|
||||
from trainlib.util.module import ModelWrapper
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Trainer[I, K: EstimatorKwargs]:
|
||||
"""
|
||||
Training interface for updating ``Estimators`` with ``Datasets``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
estimator: Estimator[K],
|
||||
device: str | None = None,
|
||||
chkpt_dir: str = "chkpt/",
|
||||
tblog_dir: str = "tblog/",
|
||||
) -> None:
|
||||
"""
|
||||
Parameters:
|
||||
estimator: `Estimator` model object
|
||||
device: device on which to carry out training
|
||||
"""
|
||||
|
||||
self.device: str
|
||||
if device is None:
|
||||
self.device = "cuda" if cuda.is_available() else "cpu"
|
||||
else:
|
||||
self.device = device
|
||||
|
||||
logger.info(f"> Trainer device: {self.device}")
|
||||
if self.device.startswith("cuda"):
|
||||
if torch.cuda.is_available():
|
||||
# extra cuda details
|
||||
logger.info(f"| > {cuda.device_count()=}")
|
||||
logger.info(f"| > {cuda.current_device()=}")
|
||||
logger.info(f"| > {cuda.get_device_name()=}")
|
||||
logger.info(f"| > {cuda.get_device_capability()=}")
|
||||
|
||||
# memory info (in GB)
|
||||
gb = 1024**3
|
||||
memory_allocated = cuda.memory_allocated() / gb
|
||||
memory_reserved = cuda.memory_reserved() / gb
|
||||
memory_total = cuda.get_device_properties(0).total_memory / gb
|
||||
|
||||
logger.info("| > CUDA memory:")
|
||||
logger.info(f"| > {memory_total=:.2f}GB")
|
||||
logger.info(f"| > {memory_reserved=:.2f}GB")
|
||||
logger.info(f"| > {memory_allocated=:.2f}GB")
|
||||
else:
|
||||
logger.warning("| > CUDA device specified but not available")
|
||||
else:
|
||||
logger.info("| > Using CPU device - no additional device info")
|
||||
|
||||
self.estimator = estimator
|
||||
self.estimator.to(self.device)
|
||||
|
||||
self.chkpt_dir = Path(chkpt_dir).resolve()
|
||||
self.tblog_dir = Path(tblog_dir).resolve()
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
Set base tracking parameters.
|
||||
"""
|
||||
|
||||
self._step: int = 0
|
||||
self._epoch: int = 0
|
||||
self._summary: dict[str, list[tuple[float, int]]] = defaultdict(list)
|
||||
|
||||
self._val_loss = float("inf")
|
||||
self._best_val_loss = float("inf")
|
||||
self._stagnant_epochs = 0
|
||||
self._best_model_state_dict: dict[str, Any] = {}
|
||||
|
||||
    def train(
        self,
        dataset: BatchedDataset[..., ..., I],
        batch_estimator_map: Callable[[I, Self], K],
        lr: float = 1e-3,
        eps: float = 1e-8,
        max_grad_norm: float | None = None,
        max_epochs: int = 10,
        stop_after_epochs: int = 5,
        batch_size: int = 256,
        val_frac: float = 0.1,
        train_transform: Transform | None = None,
        val_transform: Transform | None = None,
        dataset_split_kwargs: SplitKwargs | None = None,
        dataset_balance_kwargs: BalanceKwargs | None = None,
        dataloader_kwargs: LoaderKwargs | None = None,
        summarize_every: int = 1,
        chkpt_every: int = 1,
        resume_latest: bool = False,
        summary_writer: SummaryWriter | None = None,
    ) -> Estimator:
        """
        Run the optimization loop for the held estimator over ``dataset``.

        Note: this method attempts to implement a general scheme for passing
        needed items to the estimator's loss function from the dataloader. The
        abstract `Estimator` base only requires the model output be provided
        for any given loss calculation, but concrete estimators will often
        require additional arguments (e.g., labels or length masks, as
        is the case with sequential models). In any case, this method defers
        any further logic to the `loss` method of the underlying estimator, so
        one should take care to synchronize the sample structure with `dataset`
        to match that expected by `self.estimator.loss(...)`.

        On batch_estimator_map:

        Dataloader collate functions are responsible for mapping a collection
        of items into an item of collections, roughly speaking. If items are
        tuples of tensors,

            [
                ( [1, 1], [1, 1] ),
                ( [2, 2], [2, 2] ),
                ( [3, 3], [3, 3] ),
            ]

        the collate function maps back into the item skeleton, producing a
        single tuple of (stacked) tensors

            ( [[1, 1],
               [2, 2],
               [3, 3]],

              [[1, 1],
               [2, 2],
               [3, 3]] )

        This function should map from batches (which should be *item shaped*,
        i.e., have an `I` skeleton, even if stacked items may be different on
        the inside) into estimator keyword arguments (type `K`).

        Parameters:
            lr: learning rate (default: 1e-3)
            eps: adam EPS (default: 1e-8)
            max_grad_norm: when set, clip per-optimizer gradients to this norm
            max_epochs: maximum number of training epochs
            stop_after_epochs: number of epochs with stagnant validation losses
                to allow before early stopping. If training stops earlier, the
                parameters for the best recorded validation score are loaded
                into the estimator before the method returns. If
                `stop_after_epochs >= max_epochs`, the estimator will train
                over all epochs and return as is, irrespective of validation
                scores.
            batch_size: size of batch to use when training on the provided
                dataset
            val_frac: fraction of dataset to use for validation
            summarize_every: how often (in epochs) summaries are flushed
            chkpt_every: how often model checkpoints should be saved
            resume_latest: resume training from the latest available checkpoint
                in the `chkpt_dir`
                (NOTE(review): this flag is logged below but never otherwise
                read in this method -- confirm whether resume is implemented)
        """

        logger.info("> Begin train loop:")
        logger.info(f"| > {lr=}")
        logger.info(f"| > {eps=}")
        logger.info(f"| > {max_epochs=}")
        logger.info(f"| > {batch_size=}")
        logger.info(f"| > {val_frac=}")
        logger.info(f"| > {chkpt_every=}")
        logger.info(f"| > {resume_latest=}")
        logger.info(f"| > with device: {self.device}")
        logger.info(f"| > core count: {os.cpu_count()}")

        writer: SummaryWriter
        # timestamp prefix used to namespace this run's checkpoints
        dir_prefix = str(int(time.time()))
        if summary_writer is None:
            writer = SummaryWriter(f"{self.tblog_dir}")
        else:
            writer = summary_writer

        train_loader, val_loader = self.get_dataloaders(
            dataset,
            batch_size,
            val_frac=val_frac,
            train_transform=train_transform,
            val_transform=val_transform,
            dataset_split_kwargs=dataset_split_kwargs,
            dataset_balance_kwargs=dataset_balance_kwargs,
            dataloader_kwargs=dataloader_kwargs,
        )

        optimizers = self.estimator.optimizers(lr=lr, eps=eps)

        self._step = 0
        self._epoch = 1  # start from 1 for logging convenience
        while self._epoch <= max_epochs and not self._converged(
            self._epoch, stop_after_epochs
        ):
            # NOTE(review): prints rather than logger calls -- intentional for
            # console visibility? consider logger.info for consistency
            print(f"Training epoch {self._epoch}/{max_epochs}...")
            print(f"Stagnant epochs {self._stagnant_epochs}/{stop_after_epochs}...")

            epoch_start_time = time.time()
            train_loss_sums = []
            self.estimator.train()
            with tqdm(train_loader, unit="batch") as train_epoch:
                for i, batch_data in enumerate(train_epoch):
                    # adapt the collated batch into estimator kwargs
                    est_kwargs = batch_estimator_map(batch_data, self)
                    inputs = est_kwargs["inputs"]

                    # one-time logging
                    if self._step == 0:
                        writer.add_graph(ModelWrapper(self.estimator), est_kwargs)

                    # once-per-epoch logging
                    if i == 0:
                        self.estimator.epoch_write(
                            writer,
                            step=self._step,
                            val=False,
                            **est_kwargs
                        )

                    # loss() is a generator: losses are pulled one at a time,
                    # one per optimizer, so each backward/step can influence
                    # the next yielded loss (see Estimator.loss docs)
                    train_losses = self.estimator.loss(**est_kwargs)
                    train_loss_items = []
                    for o_idx, optimizer in enumerate(optimizers):
                        optimizer.zero_grad()
                        train_loss = next(train_losses)

                        # lazily grow the per-loss running sums
                        if len(train_loss_sums) <= o_idx:
                            train_loss_sums.append(0.0)

                        train_loss_item = train_loss.item()
                        train_loss_sums[o_idx] += train_loss_item
                        train_loss_items.append(train_loss_item)

                        train_loss.backward()

                        # clip gradients for optimizer's parameters
                        if max_grad_norm is not None:
                            opt_params = self._get_optimizer_parameters(optimizer)
                            clip_grad_norm_(opt_params, max_norm=max_grad_norm)

                        optimizer.step()

                    # global step counts *samples*, not batches
                    self._step += len(inputs)

                    for train_loss_item, train_loss_sum in zip(
                        train_loss_items,
                        train_loss_sums,
                        strict=True,
                    ):
                        train_epoch.set_postfix(loss=f"{train_loss_sum/(i+1):8.2f}")
                        self._add_summary_item("train_loss", train_loss_item)

                    estimator_metrics = self.estimator.metrics(**est_kwargs)
                    for metric_name, metric_value in estimator_metrics.items():
                        self._add_summary_item(f"train_{metric_name}", metric_value)

            # advance epoch-dependent estimator state (e.g., schedules)
            self.estimator.epoch_step()

            for li, train_loss_sum in enumerate(train_loss_sums):
                self._add_summary_item(
                    f"train_loss{li}_epoch", train_loss_sum / len(train_loader)
                )

            if val_frac > 0:
                val_loss_sums = []
                self.estimator.eval()
                with tqdm(val_loader, unit="batch") as val_epoch:
                    for i, batch_data in enumerate(val_epoch):
                        est_kwargs = batch_estimator_map(batch_data, self)
                        inputs = est_kwargs["inputs"]

                        # once-per-epoch logging
                        if i == 0:
                            self.estimator.epoch_write(
                                writer,
                                step=self._step,
                                val=True,
                                **est_kwargs
                            )

                        val_losses = self.estimator.loss(**est_kwargs)
                        val_loss_items = []
                        for o_idx in range(len(optimizers)):
                            val_loss = next(val_losses)

                            if len(val_loss_sums) <= o_idx:
                                val_loss_sums.append(0.0)

                            val_loss_item = val_loss.item()
                            val_loss_sums[o_idx] += val_loss_item
                            val_loss_items.append(val_loss_item)

                        for val_loss_item, val_loss_sum in zip(
                            val_loss_items,
                            val_loss_sums,
                            strict=True,
                        ):
                            val_epoch.set_postfix(loss=f"{val_loss_sum/(i+1):8.2f}")
                            self._add_summary_item("val_loss", val_loss_item)

                        estimator_metrics = self.estimator.metrics(**est_kwargs)
                        for metric_name, metric_value in estimator_metrics.items():
                            self._add_summary_item(f"val_{metric_name}", metric_value)

                for li, val_loss_sum in enumerate(val_loss_sums):
                    self._add_summary_item(
                        f"val_loss{li}_epoch", val_loss_sum / len(val_loader)
                    )

                # convergence of multiple losses may be ambiguous
                self._val_loss = sum(val_loss_sums) / len(val_loader)

            self._add_summary_item("epoch_time_sec", time.time() - epoch_start_time)

            if self._epoch % summarize_every == 0:
                self._summarize(writer, self._epoch)

            # save checkpoint
            if self._epoch % chkpt_every == 0:
                self.save_model(
                    self._epoch, self.chkpt_dir, dir_prefix
                )

            self._epoch += 1

        return self.estimator
|
||||
|
||||
def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
|
||||
converged = False
|
||||
|
||||
if epoch == 1 or self._val_loss < self._best_val_loss:
|
||||
self._best_val_loss = self._val_loss
|
||||
self._stagnant_epochs = 0
|
||||
self._best_model_state_dict = deepcopy(self.estimator.state_dict())
|
||||
else:
|
||||
self._stagnant_epochs += 1
|
||||
|
||||
if self._stagnant_epochs >= stop_after_epochs:
|
||||
self.estimator.load_state_dict(self._best_model_state_dict)
|
||||
converged = True
|
||||
|
||||
return converged
|
||||
|
||||
@staticmethod
def get_dataloaders(
    dataset: BatchedDataset,
    batch_size: int,
    val_frac: float = 0.1,
    train_transform: Transform | None = None,
    val_transform: Transform | None = None,
    dataset_split_kwargs: SplitKwargs | None = None,
    dataset_balance_kwargs: BalanceKwargs | None = None,
    dataloader_kwargs: LoaderKwargs | None = None,
) -> tuple[DataLoader, DataLoader]:
    """
    Create training and validation dataloaders for the provided dataset.

    When ``val_frac <= 0`` no split is performed: the full dataset backs the
    training loader, and the second element of the returned pair is a
    placeholder loader over a bare ``Dataset()``.
    """

    split_kwargs: SplitKwargs = (
        {} if dataset_split_kwargs is None else dataset_split_kwargs
    )
    extra_loader_kwargs: LoaderKwargs = (
        {} if dataloader_kwargs is None else dataloader_kwargs
    )

    if dataset_balance_kwargs is not None:
        dataset.balance(**dataset_balance_kwargs)

    if val_frac <= 0:
        dataset.post_transform = train_transform
        loader_kwargs: LoaderKwargs = {
            "batch_size": min(batch_size, len(dataset)),
            "num_workers": 0,
            "shuffle": True,
            **extra_loader_kwargs,
        }
        # NOTE(review): the placeholder DataLoader(Dataset()) has no length
        # or items — callers must not iterate it when val_frac <= 0; confirm.
        return (
            DataLoader(dataset, **loader_kwargs),
            DataLoader(Dataset()),
        )

    train_dataset, val_dataset = dataset.split(
        [1 - val_frac, val_frac],
        **split_kwargs,
    )

    # Dataset.split() returns light Subset objects of shallow copies of the
    # underlying dataset, so each split can carry its own transform without
    # overwriting the other's.
    train_dataset.post_transform = train_transform
    val_dataset.post_transform = val_transform

    train_loader_kwargs: LoaderKwargs = {
        "batch_size": min(batch_size, len(train_dataset)),
        "num_workers": 0,
        "shuffle": True,
        **extra_loader_kwargs,
    }
    val_loader_kwargs: LoaderKwargs = {
        "batch_size": min(batch_size, len(val_dataset)),
        "num_workers": 0,
        "shuffle": True,  # shuffle to prevent homogeneous val batches
        **extra_loader_kwargs,
    }

    return (
        DataLoader(train_dataset, **train_loader_kwargs),
        DataLoader(val_dataset, **val_loader_kwargs),
    )
|
||||
|
||||
def _summarize(self, writer: SummaryWriter, epoch: int) -> None:
    """
    Flush the accumulated training summary to the TB summary writer.

    Every buffered (value, step) record is written as a scalar, per-metric
    means are printed for the epoch, and the buffer is then reset.
    """

    collected = defaultdict(list)
    for metric_name, records in self._summary.items():
        for metric_value, step in records:
            writer.add_scalar(metric_name, metric_value, step)
            collected[metric_name].append(metric_value)

    print(f"==== Epoch [{epoch}] summary ====")
    for metric_name, metric_values in collected.items():
        mean_value = torch.tensor(metric_values).mean().item()
        print(f"> ({len(metric_values)}) {metric_name} :: {mean_value:.2f}")

    writer.flush()
    # Start a fresh buffer for the next summary window.
    self._summary = defaultdict(list)
|
||||
|
||||
def _get_optimizer_parameters(
|
||||
self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
) -> list[Tensor]:
|
||||
return [
|
||||
param
|
||||
for param_group in optimizer.param_groups
|
||||
for param in param_group["params"]
|
||||
if param.grad is not None
|
||||
]
|
||||
|
||||
def _add_summary_item(self, name: str, value: float) -> None:
|
||||
self._summary[name].append((value, self._step))
|
||||
|
||||
def save_model(
|
||||
self,
|
||||
epoch: int,
|
||||
chkpt_dir: str | Path,
|
||||
dir_prefix: str,
|
||||
) -> None:
|
||||
"""
|
||||
Save a model checkpoint.
|
||||
"""
|
||||
|
||||
model_buff = BytesIO()
|
||||
torch.save(self.estimator.state_dict(), model_buff)
|
||||
model_buff.seek(0)
|
||||
|
||||
model_class = self.estimator.__class__.__name__
|
||||
chkpt_name = f"m_{model_class}-e_{epoch}.pth"
|
||||
|
||||
chkpt_dir = Path(chkpt_dir, dir_prefix)
|
||||
chkpt_path = Path(chkpt_dir, chkpt_name)
|
||||
|
||||
chkpt_dir.mkdir(parents=True, exist_ok=True)
|
||||
chkpt_path.write_bytes(model_buff.getvalue())
|
||||
|
||||
def load_model(
|
||||
self,
|
||||
epoch: int,
|
||||
chkpt_dir: str,
|
||||
) -> None:
|
||||
"""
|
||||
Load a model checkpoint from a given epoch.
|
||||
|
||||
Note that this assumes the model was saved via `Trainer.save_model()`,
|
||||
and the estimator provided to this `Trainer` instance matches the
|
||||
architecture of the checkpoint model being loaded.
|
||||
"""
|
||||
|
||||
model_class = self.estimator.__class__.__name__
|
||||
chkpt_name = f"m_{model_class}-e_{epoch}.pth"
|
||||
chkpt_path = Path(chkpt_dir, chkpt_name)
|
||||
|
||||
model_buff = BytesIO(chkpt_path.read_bytes())
|
||||
model_buff.seek(0)
|
||||
|
||||
model_dict = torch.load(model_buff, weights_only=True)
|
||||
self.estimator.load_state_dict(model_dict)
|
||||
11
trainlib/transform.py
Normal file
11
trainlib/transform.py
Normal file
@@ -0,0 +1,11 @@
|
||||
class Transform[I]:
|
||||
"""
|
||||
Dataset transform base class.
|
||||
|
||||
In places that directly reference a base ``Transform[I]``, a hint
|
||||
``Callable[[I], I]`` would suffice. This class exists to allow nominal
|
||||
checks for purpose-built transforms.
|
||||
"""
|
||||
|
||||
def __call__(self, item: I) -> I:
|
||||
raise NotImplementedError
|
||||
0
trainlib/utils/__init__.py
Normal file
0
trainlib/utils/__init__.py
Normal file
54
trainlib/utils/job.py
Normal file
54
trainlib/utils/job.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import logging
|
||||
import concurrent
|
||||
from concurrent.futures import Future, as_completed
|
||||
|
||||
from tqdm import tqdm
|
||||
from colorama import Fore, Style
|
||||
|
||||
from mema.util.text import color_text
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def process_futures(
    futures: list[Future],
    desc: str | None = None,
    unit: str | None = None,
) -> None:
    """
    Await a collection of futures, rendering live progress with per-outcome
    (success / cancelled / errored) counts in the bar description.

    Results are discarded; cancellations and exceptions are logged rather
    than re-raised.
    """

    desc = "Awaiting futures" if desc is None else desc
    unit = "it" if unit is None else unit

    submitted = len(futures)
    outcomes = {"success": 0, "cancelled": 0, "errored": 0}

    progress_bar = tqdm(
        total=len(futures),
        desc=f"{desc} [submitted {len(futures)}]",
        unit=unit,
    )

    for future in as_completed(futures):
        try:
            future.result()
            outcomes["success"] += 1
        except concurrent.futures.CancelledError as e:
            outcomes["cancelled"] += 1
            logger.error(f'Future cancelled; "{e}"')
        except Exception as e:
            outcomes["errored"] += 1
            logger.warning(f'Future failed with unknown exception "{e}"')

        # Refresh the colored tally after every completion.
        suc_txt = color_text(f"{outcomes['success']}", Fore.GREEN)
        can_txt = color_text(f"{outcomes['cancelled']}", Fore.YELLOW)
        err_txt = color_text(f"{outcomes['errored']}", Fore.RED)
        tot_txt = color_text(f"{sum(outcomes.values())}", Style.BRIGHT)
        progress_bar.set_description(
            f"{desc} [{tot_txt} / {submitted} | {suc_txt} {can_txt} {err_txt}]"
        )
        progress_bar.update(n=1)

    progress_bar.close()
|
||||
25
trainlib/utils/module.py
Normal file
25
trainlib/utils/module.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class ModelWrapper(nn.Module):
    """
    Thin ``nn.Module`` adapter whose ``forward`` takes a single dict and
    splats it into the wrapped model's ``forward`` as keyword arguments.
    """

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        # Registered as a submodule, so parameters/state flow through.
        self.model = model

    def forward(self, kwargs):
        # Splat the dict so the wrapped model sees ordinary keyword args.
        # (A commented-out alternative signature was removed as dead code.)
        return self.model(**kwargs)
|
||||
|
||||
|
||||
def get_grad_norm(model: nn.Module, p: int = 2) -> float:
    """
    Compute the p-norm of all gradients currently stored on ``model``'s
    parameters, treated as one flattened vector.

    Parameters without ``requires_grad`` or without a populated ``.grad``
    are skipped.
    """

    total = 0.0
    for param in model.parameters():
        if param.requires_grad and param.grad is not None:
            total += float(torch.abs(param.grad).pow(p).sum().item())

    return total ** (1 / p)
|
||||
|
||||
8
trainlib/utils/text.py
Normal file
8
trainlib/utils/text.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from typing import Any
|
||||
|
||||
from colorama import Style
|
||||
|
||||
|
||||
def color_text(text: str, *colorama_args: Any) -> str:
    """Wrap ``text`` in the given colorama codes, resetting styling afterwards."""
    prefix = "".join(colorama_args)
    return f"{prefix}{text}{Style.RESET_ALL}"
|
||||
|
||||
52
trainlib/utils/type.py
Normal file
52
trainlib/utils/type.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import Any, TypedDict
|
||||
from collections.abc import Callable, Iterable
|
||||
|
||||
from torch import Tensor
|
||||
from torch.utils.data.sampler import Sampler
|
||||
|
||||
from mema.dataset import BatchedDataset
|
||||
|
||||
|
||||
class LoaderKwargs(TypedDict, total=False):
    """
    Optional keyword arguments forwarded to ``torch.utils.data.DataLoader``.

    ``total=False`` makes every key optional, so call sites can pass a
    statically checked subset of the DataLoader constructor arguments.
    """

    batch_size: int
    shuffle: bool
    sampler: Sampler | Iterable | None
    batch_sampler: Sampler[list] | Iterable[list] | None
    num_workers: int
    collate_fn: Callable[[list], Any]
    pin_memory: bool
    drop_last: bool
    timeout: float
    worker_init_fn: Callable[[int], None]
    multiprocessing_context: object
    generator: object
    prefetch_factor: int
    persistent_workers: bool
    pin_memory_device: str
    in_order: bool
|
||||
|
||||
|
||||
class SplitKwargs(TypedDict, total=False):
    """
    Optional keyword arguments for ``BatchedDataset.split()``.

    All keys optional (``total=False``).
    """

    # NOTE(review): a ``dataset`` key inside split kwargs is unusual —
    # confirm BatchedDataset.split() actually accepts it.
    dataset: BatchedDataset | None
    by_attr: str | list[str | None] | None
    shuffle_strata: bool
|
||||
|
||||
|
||||
class BalanceKwargs(TypedDict, total=False):
    """
    Optional keyword arguments for ``BatchedDataset.balance()``.

    All keys optional (``total=False``).
    """

    by_attr: str | list[str | None] | None
    split_min_sizes: list[int] | None
    split_max_sizes: list[int] | None
    shuffle_strata: bool
|
||||
|
||||
|
||||
class OptimizerKwargs(TypedDict, total=False):
    """
    Optional keyword arguments for constructing a torch optimizer.

    The key set mirrors an Adam-style constructor signature (``betas``,
    ``amsgrad``, ``fused``); all keys optional (``total=False``).
    """

    lr: float | Tensor
    betas: tuple[float | Tensor, float | Tensor]
    eps: float
    weight_decay: float
    amsgrad: bool
    maximize: bool
    foreach: bool | None
    capturable: bool
    differentiable: bool
    fused: bool | None
|
||||
Reference in New Issue
Block a user