From faeef9c72aa37d77c645c526eb81f4a1cadc60cd Mon Sep 17 00:00:00 2001 From: smgr Date: Thu, 5 Mar 2026 01:36:40 -0800 Subject: [PATCH] reformat docstrings for sphinx --- pyproject.toml | 2 +- trainlib/dataset.py | 162 +++++++++++++++++++++++--------------------- trainlib/trainer.py | 55 ++++++++------- 3 files changed, 117 insertions(+), 102 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9f434fa..b491b97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "trainlib" -version = "0.1.0" +version = "0.1.1" description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training." requires-python = ">=3.13" authors = [ diff --git a/trainlib/dataset.py b/trainlib/dataset.py index cb2251f..8429bbc 100644 --- a/trainlib/dataset.py +++ b/trainlib/dataset.py @@ -1,5 +1,5 @@ """ -Marginalizing out the modality layer: +.. admonition:: Marginalizing out the modality layer With ``domain`` being an instance variable, one possible interpretation of the object structures here is that one could completely abstract away @@ -58,64 +58,71 @@ Marginalizing out the modality layer: particular case of ``_process_batch_data()``, it feels much better when it's on the inside.) -Holding: - @abstractmethod - def _get_uri_groups(self) -> Iterable[tuple[U, ...]]: - Get URI groups for each batch. +.. admonition:: Holding area - If there's more than one URI per batch (e.g., a data file and a - metadata file), zip the URIs such that we have a tuple of URIs per - batch. + .. code-block:: python - Note that this effectively defines the index style over batches in the - attached domain. We get an ``int -> tuple[U, ...]`` map that turns - batch indices into URIs that can be read under the domain. 
- ``get_batch()`` turns an integer index into its corresponding - ``tuple[U, ...]``, reading the resources with ``_read_resources()`` in - the tuple, treating them as providers of batched data. - ``_read_resources()`` passes through to the attached domain logic, - which, although common, need not supply an explicit iterable of batch - items: we just access items with ``__getitem__()`` and may ask for - ``__len__``. So the returned URI group collection (this method) does - need to be iterable to measure the number of batches, but the batch - objects that are ultimately produced by these URI groups need not be - iterables themselves. + @abstractmethod + def _get_uri_groups(self) -> Iterable[tuple[U, ...]]: + Get URI groups for each batch. - raise NotImplementedError + If there's more than one URI per batch (e.g., a data file and a + metadata file), zip the URIs such that we have a tuple of URIs per + batch. - def _read_resources( - self, - uri_group: tuple[U, ...], - batch_index: int - ) -> tuple[R, ...]: - Read batch files at the provided paths. + Note that this effectively defines the index style over batches in + the attached domain. We get an ``int -> tuple[U, ...]`` map that + turns batch indices into URIs that can be read under the domain. + ``get_batch()`` turns an integer index into its corresponding + ``tuple[U, ...]``, reading the resources with ``_read_resources()`` + in the tuple, treating them as providers of batched data. + ``_read_resources()`` passes through to the attached domain logic, + which, although common, need not supply an explicit iterable of + batch items: we just access items with ``__getitem__()`` and may + ask for ``__len__``. So the returned URI group collection (this + method) does need to be iterable to measure the number of batches, + but the batch objects that are ultimately produced by these URI + groups need not be iterables themselves. 
- This method should operate on a single tuple from the list of batch - tuples returned by the ``_get_uri_groups()`` method. That is, it reads - all of the resources for a single batch and returns a tuple of the same - size with their contents. + raise NotImplementedError - Note: the dependence on a batch index is mostly here to make - multi-dataset composition easier later. In-dataset, you don't need to - know the batch index to to simply process URIs, but across datasets you - need it to find out the origin of the batch (and process those URIs - accordingly). + def _read_resources( + self, + uri_group: tuple[U, ...], + batch_index: int + ) -> tuple[R, ...]: + Read batch files at the provided paths. - return tuple(self.domain.read(uri) for uri in uri_group) + This method should operate on a single tuple from the list of batch + tuples returned by the ``_get_uri_groups()`` method. That is, it + reads all of the resources for a single batch and returns a tuple + of the same size with their contents. -# pulling the type variable out of the inline generic b/c `ty` has trouble -# understanding bound type variables in subclasses (specifically with Self@) -T = TypeVar("T", bound=NamedTuple) + Note: the dependence on a batch index is mostly here to make + multi-dataset composition easier later. In-dataset, you don't need + to know the batch index to simply process URIs, but across + datasets you need it to find out the origin of the batch (and + process those URIs accordingly). -class NamedTupleDataset[I](Dataset): - def __init__(self, data_list: list[I]) -> None: - self.data_list = data_list + return tuple(self.domain.read(uri) for uri in uri_group) - def __len__(self) -> int: - return len(self.data_list) + ..
code-block:: python - def __getitem__(self, index: int) -> I: - return self.data_list[index] + # pulling the type variable out of the inline generic b/c `ty` has + # trouble understanding bound type variables in subclasses + # (specifically with Self@) + + T = TypeVar("T", bound=NamedTuple) + + class NamedTupleDataset[I](Dataset): + def __init__(self, data_list: list[I]) -> None: + self.data_list = data_list + + def __len__(self) -> int: + return len(self.data_list) + + def __getitem__(self, index: int) -> I: + return self.data_list[index] """ import math @@ -156,38 +163,41 @@ class BatchedDataset[U, R, I](Dataset): which are used to concretize a domain ``Domain[U, R]``), and an item type ``T`` (which has a ``tuple`` upper bound). - Pipeline overview: + + .. admonition:: Pipeline overview - ``` - Domain -> [U] (get _batch_uris) - U -> R (domain access ; Rs provide batches) - R -> [I] (cache here ; _process_batch_data to use load_transform) - [I] -> I (human item obj ; _get_item) - I -> **P (final packed item ; __getitem__ to use transform) - ``` + .. code-block:: text + + Domain -> [U] (get _batch_uris) + U -> R (domain access ; Rs provide batches) + R -> [I] (cache here ; _process_batch_data to use load_transform) + [I] -> I (human item obj ; _get_item) + I -> **P (final packed item ; __getitem__ to use transform) - Note^1: as far as positioning, this class is meant to play nice with - PyTorch DataLoaders, hence the inheritance from ``torch.Dataset``. The - value add for this over the ``torch.Dataset`` base is almost entirely in - the logic it implements to map out of *batched resources* that are holding - data, and flattening it out into typical dataset items. There are also some - QoL items when it comes to splitting and balancing samples. + Note^1: as far as positioning, this class is meant to play nice with + PyTorch DataLoaders, hence the inheritance from ``torch.Dataset``.
The + value add for this over the ``torch.Dataset`` base is almost entirely + in the logic it implements to map out of *batched resources* that are + holding data, and flattening it out into typical dataset items. There + are also some QoL items when it comes to splitting and balancing + samples. - Note^2: even though ``Domains`` implement iterators over their URIs, this - doesn't imply a ``BatchedDataset`` is iterable. This just means we can walk - over the resources that provide data, but we don't necessarily presuppose - an ordered walk over samples within batches. Point being: - ``torch.Dataset``, not ``torch.IterableDataset``, is the appropriate - superclass, even when we're working around iterable ``Domains``. + Note^2: even though ``Domains`` implement iterators over their URIs, + this doesn't imply a ``BatchedDataset`` is iterable. This just means we + can walk over the resources that provide data, but we don't necessarily + presuppose an ordered walk over samples within batches. Point being: + ``torch.Dataset``, not ``torch.IterableDataset``, is the appropriate + superclass, even when we're working around iterable ``Domains``. - Note^3: transforms are expected to operate on ``I``-items and produce - ``I``-items. They shouldn't be the "introducers" of ``I`` types from some - other intermediate representation, nor should they map from ``I`` to - something else. Point being: the dataset definition should be able to map - resources ``R`` to ``I`` without a transform: that much should be baked - into the class definition. If you find you're expecting the transform to do - that for you, you should consider pulling in some common structure across - the allowed transforms and make it a fixed part of the class. + Note^3: transforms are expected to operate on ``I``-items and produce + ``I``-items. They shouldn't be the "introducers" of ``I`` types from + some other intermediate representation, nor should they map from ``I`` + to something else. 
Point being: the dataset definition should be able + to map resources ``R`` to ``I`` without a transform: that much should + be baked into the class definition. If you find you're expecting the + transform to do that for you, you should consider pulling in some + common structure across the allowed transforms and make it a fixed part + of the class. """ def __init__( diff --git a/trainlib/trainer.py b/trainlib/trainer.py index 9f1d4dc..0d6882c 100644 --- a/trainlib/trainer.py +++ b/trainlib/trainer.py @@ -268,40 +268,45 @@ class Trainer[I, K: EstimatorKwargs]: """ Note: this method attempts to implement a general scheme for passing needed items to the estimator's loss function from the dataloader. The - abstract `Estimator` base only requires the model output be provided + abstract ``Estimator`` base only requires the model output be provided for any given loss calculation, but concrete estimators will often - require additional arguments (e.g., labels or length masks, as - is the case with sequential models). In any case, this method defers - any further logic to the `loss` method of the underlying estimator, so + require additional arguments (e.g., labels or length masks, as is the + case with sequential models). In any case, this method defers any + further logic to the ``loss`` method of the underlying estimator, so one should take care to synchronize the sample structure with `dataset` - to match that expected by `self.estimator.loss(...)`. + to match that expected by ``self.estimator.loss(...)``. + + .. admonition:: On batch_estimator_map - On batch_estimator_map: + Dataloader collate functions are responsible for mapping a + collection of items into an item of collections, roughly speaking. + If items are tuples of tensors, - Dataloader collate functions are responsible for mapping a collection - of items into an item of collections, roughly speaking. If items are - tuples of tensors, + .. 
code-block:: - [ - ( [1, 1], [1, 1] ), - ( [2, 2], [2, 2] ), - ( [3, 3], [3, 3] ), - ] + [ + ( [1, 1], [1, 1] ), + ( [2, 2], [2, 2] ), + ( [3, 3], [3, 3] ), + ] - the collate function maps back into the item skeleton, producing a - single tuple of (stacked) tensors + the collate function maps back into the item skeleton, producing a + single tuple of (stacked) tensors + + .. code-block:: - ( [[1, 1], - [2, 2], - [3, 3]], + ( [[1, 1], + [2, 2], + [3, 3]], - [[1, 1], - [2, 2], - [3, 3]] ) + [[1, 1], + [2, 2], + [3, 3]] ) - This function should map from batches (which should be *item shaped*, - i.e., have an `I` skeleton, even if stacked items may be different on - the inside) into estimator keyword arguments (type `K`). + This function should map from batches (which should be *item + shaped*, i.e., have an ``I`` skeleton, even if stacked items may be + different on the inside) into estimator keyword arguments (type + ``K``). Parameters: lr: learning rate (default: 1e-3)