reformat docstrings for sphinx

This commit is contained in:
2026-03-05 01:36:40 -08:00
parent 805262dfc4
commit faeef9c72a
3 changed files with 117 additions and 102 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "trainlib" name = "trainlib"
version = "0.1.0" version = "0.1.1"
description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training." description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training."
requires-python = ">=3.13" requires-python = ">=3.13"
authors = [ authors = [

View File

@@ -1,5 +1,5 @@
""" """
Marginalizing out the modality layer: .. admonition:: Marginalizing out the modality layer
With ``domain`` being an instance variable, one possible interpretation of With ``domain`` being an instance variable, one possible interpretation of
the object structures here is that one could completely abstract away the object structures here is that one could completely abstract away
@@ -58,64 +58,71 @@ Marginalizing out the modality layer:
particular case of ``_process_batch_data()``, it feels much better when particular case of ``_process_batch_data()``, it feels much better when
it's on the inside.) it's on the inside.)
Holding: .. admonition:: Holding area
@abstractmethod
def _get_uri_groups(self) -> Iterable[tuple[U, ...]]:
Get URI groups for each batch.
If there's more than one URI per batch (e.g., a data file and a .. code-block:: python
metadata file), zip the URIs such that we have a tuple of URIs per
batch.
Note that this effectively defines the index style over batches in the @abstractmethod
attached domain. We get an ``int -> tuple[U, ...]`` map that turns def _get_uri_groups(self) -> Iterable[tuple[U, ...]]:
batch indices into URIs that can be read under the domain. Get URI groups for each batch.
``get_batch()`` turns an integer index into its corresponding
``tuple[U, ...]``, reading the resources with ``_read_resources()`` in
the tuple, treating them as providers of batched data.
``_read_resources()`` passes through to the attached domain logic,
which, although common, need not supply an explicit iterable of batch
items: we just access items with ``__getitem__()`` and may ask for
``__len__``. So the returned URI group collection (this method) does
need to be iterable to measure the number of batches, but the batch
objects that are ultimately produced by these URI groups need not be
iterables themselves.
raise NotImplementedError If there's more than one URI per batch (e.g., a data file and a
metadata file), zip the URIs such that we have a tuple of URIs per
batch.
def _read_resources( Note that this effectively defines the index style over batches in
self, the attached domain. We get an ``int -> tuple[U, ...]`` map that
uri_group: tuple[U, ...], turns batch indices into URIs that can be read under the domain.
batch_index: int ``get_batch()`` turns an integer index into its corresponding
) -> tuple[R, ...]: ``tuple[U, ...]``, reading the resources with ``_read_resources()``
Read batch files at the provided paths. in the tuple, treating them as providers of batched data.
``_read_resources()`` passes through to the attached domain logic,
which, although common, need not supply an explicit iterable of
batch items: we just access items with ``__getitem__()`` and may
ask for ``__len__``. So the returned URI group collection (this
method) does need to be iterable to measure the number of batches,
but the batch objects that are ultimately produced by these URI
groups need not be iterables themselves.
This method should operate on a single tuple from the list of batch raise NotImplementedError
tuples returned by the ``_get_uri_groups()`` method. That is, it reads
all of the resources for a single batch and returns a tuple of the same
size with their contents.
Note: the dependence on a batch index is mostly here to make def _read_resources(
multi-dataset composition easier later. In-dataset, you don't need to self,
know the batch index to to simply process URIs, but across datasets you uri_group: tuple[U, ...],
need it to find out the origin of the batch (and process those URIs batch_index: int
accordingly). ) -> tuple[R, ...]:
Read batch files at the provided paths.
return tuple(self.domain.read(uri) for uri in uri_group) This method should operate on a single tuple from the list of batch
tuples returned by the ``_get_uri_groups()`` method. That is, it
reads all of the resources for a single batch and returns a tuple
of the same size with their contents.
# pulling the type variable out of the inline generic b/c `ty` has trouble Note: the dependence on a batch index is mostly here to make
# understanding bound type variables in subclasses (specifically with Self@) multi-dataset composition easier later. In-dataset, you don't need
T = TypeVar("T", bound=NamedTuple) to know the batch index to to simply process URIs, but across
datasets you need it to find out the origin of the batch (and
process those URIs accordingly).
class NamedTupleDataset[I](Dataset): return tuple(self.domain.read(uri) for uri in uri_group)
def __init__(self, data_list: list[I]) -> None:
self.data_list = data_list
def __len__(self) -> int: .. code-block:: python
return len(self.data_list)
def __getitem__(self, index: int) -> I: # pulling the type variable out of the inline generic b/c `ty` has
return self.data_list[index] # trouble understanding bound type variables in subclasses
# (specifically with Self@)
T = TypeVar("T", bound=NamedTuple)
class NamedTupleDataset[I](Dataset):
def __init__(self, data_list: list[I]) -> None:
self.data_list = data_list
def __len__(self) -> int:
return len(self.data_list)
def __getitem__(self, index: int) -> I:
return self.data_list[index]
""" """
import math import math
@@ -156,38 +163,41 @@ class BatchedDataset[U, R, I](Dataset):
which are used to concretize a domain ``Domain[U, R]``), and an item type which are used to concretize a domain ``Domain[U, R]``), and an item type
``T`` (which has a ``tuple`` upper bound). ``T`` (which has a ``tuple`` upper bound).
Pipeline overview:
``` .. admonition:: Pipeline overview
Domain -> [U] (get _batch_uris)
U -> R (domain access ; Rs provide batches)
R -> [I] (cache here ; _process_batch_data to use load_transform)
[I] -> I (human item obj ; _get_item)
I -> **P (final packed item ; __getitem__ to use transform)
```
Note^1: as far as positioning, this class is meant to play nice with .. code-block:: python
PyTorch DataLoaders, hence the inheritance from ``torch.Dataset``. The
value add for this over the ``torch.Dataset`` base is almost entirely in
the logic it implements to map out of *batched resources* that are holding
data, and flattening it out into typical dataset items. There are also some
QoL items when it comes to splitting and balancing samples.
Note^2: even though ``Domains`` implement iterators over their URIs, this Domain -> [U] (get _batch_uris)
doesn't imply a ``BatchedDataset`` is iterable. This just means we can walk U -> R (domain access ; Rs provide batches)
over the resources that provide data, but we don't necessarily presuppose R -> [I] (cache here ; _process_batch_data to use load_transform)
an ordered walk over samples within batches. Point being: [I] -> I (human item obj ; _get_item)
``torch.Dataset``, not ``torch.IterableDataset``, is the appropriate I -> **P (final packed item ; __getitem__ to use transform)
superclass, even when we're working around iterable ``Domains``.
Note^3: transforms are expected to operate on ``I``-items and produce Note^1: as far as positioning, this class is meant to play nice with
``I``-items. They shouldn't be the "introducers" of ``I`` types from some PyTorch DataLoaders, hence the inheritance from ``torch.Dataset``. The
other intermediate representation, nor should they map from ``I`` to value add for this over the ``torch.Dataset`` base is almost entirely
something else. Point being: the dataset definition should be able to map in the logic it implements to map out of *batched resources* that are
resources ``R`` to ``I`` without a transform: that much should be baked holding data, and flattening it out into typical dataset items. There
into the class definition. If you find you're expecting the transform to do are also some QoL items when it comes to splitting and balancing
that for you, you should consider pulling in some common structure across samples.
the allowed transforms and make it a fixed part of the class.
Note^2: even though ``Domains`` implement iterators over their URIs,
this doesn't imply a ``BatchedDataset`` is iterable. This just means we
can walk over the resources that provide data, but we don't necessarily
presuppose an ordered walk over samples within batches. Point being:
``torch.Dataset``, not ``torch.IterableDataset``, is the appropriate
superclass, even when we're working around iterable ``Domains``.
Note^3: transforms are expected to operate on ``I``-items and produce
``I``-items. They shouldn't be the "introducers" of ``I`` types from
some other intermediate representation, nor should they map from ``I``
to something else. Point being: the dataset definition should be able
to map resources ``R`` to ``I`` without a transform: that much should
be baked into the class definition. If you find you're expecting the
transform to do that for you, you should consider pulling in some
common structure across the allowed transforms and make it a fixed part
of the class.
""" """
def __init__( def __init__(

View File

@@ -268,40 +268,45 @@ class Trainer[I, K: EstimatorKwargs]:
""" """
Note: this method attempts to implement a general scheme for passing Note: this method attempts to implement a general scheme for passing
needed items to the estimator's loss function from the dataloader. The needed items to the estimator's loss function from the dataloader. The
abstract `Estimator` base only requires the model output be provided abstract ``Estimator`` base only requires the model output be provided
for any given loss calculation, but concrete estimators will often for any given loss calculation, but concrete estimators will often
require additional arguments (e.g., labels or length masks, as require additional arguments (e.g., labels or length masks, as is the
is the case with sequential models). In any case, this method defers case with sequential models). In any case, this method defers any
any further logic to the `loss` method of the underlying estimator, so further logic to the ``loss`` method of the underlying estimator, so
one should take care to synchronize the sample structure with `dataset` one should take care to synchronize the sample structure with `dataset`
to match that expected by `self.estimator.loss(...)`. to match that expected by ``self.estimator.loss(...)``.
On batch_estimator_map: .. admonition:: On batch_estimator_map
Dataloader collate functions are responsible for mapping a collection Dataloader collate functions are responsible for mapping a
of items into an item of collections, roughly speaking. If items are collection of items into an item of collections, roughly speaking.
tuples of tensors, If items are tuples of tensors,
[ .. code-block::
( [1, 1], [1, 1] ),
( [2, 2], [2, 2] ),
( [3, 3], [3, 3] ),
]
the collate function maps back into the item skeleton, producing a [
single tuple of (stacked) tensors ( [1, 1], [1, 1] ),
( [2, 2], [2, 2] ),
( [3, 3], [3, 3] ),
]
( [[1, 1], the collate function maps back into the item skeleton, producing a
[2, 2], single tuple of (stacked) tensors
[3, 3]],
[[1, 1], .. code-block::
[2, 2],
[3, 3]] )
This function should map from batches (which should be *item shaped*, ( [[1, 1],
i.e., have an `I` skeleton, even if stacked items may be different on [2, 2],
the inside) into estimator keyword arguments (type `K`). [3, 3]],
[[1, 1],
[2, 2],
[3, 3]] )
This function should map from batches (which should be *item
shaped*, i.e., have an ``I`` skeleton, even if stacked items may be
different on the inside) into estimator keyword arguments (type
``K``).
Parameters: Parameters:
lr: learning rate (default: 1e-3) lr: learning rate (default: 1e-3)