Compare commits

..

2 Commits

Author SHA1 Message Date
faeef9c72a reformat docstrings for sphinx 2026-03-05 01:36:40 -08:00
805262dfc4 refactor Trainer train/val loop 2026-03-05 01:14:24 -08:00
15 changed files with 579 additions and 230 deletions

View File

@@ -2,16 +2,10 @@
Package summary goes here, ideally with a diagram
# Install
Installation instructions
The `trainlib` package can be installed from PyPI:
```sh
pip install <package>
```
or as a CLI tool
```sh
uv tool install <package>
pip install trainlib
```
# Development
@@ -20,16 +14,16 @@ uv tool install <package>
- Depending on needs, install the development dependencies with `uv sync
--extra dev`.
# Testing
## Testing
- To run the unit tests, make sure to first have the test dependencies
installed with `uv sync --extra test`, then run `make test`.
- For notebook testing, run `make install-kernel` to make the environment
available as a Jupyter kernel (to be selected when running notebooks).
# Documentation
## Documentation
- Install the documentation dependencies with `uv sync --extra doc`.
- Run `make docs-build` (optionally preceded by `make docs-clean`), and serve
locally with `docs-serve`.
locally with `make docs-serve`.
# Development remarks
- Across `Trainer` / `Estimator` / `Dataset`, I've considered a

20
doc/Makefile Normal file
View File

@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

9
doc/_templates/autosummary.md vendored Normal file
View File

@@ -0,0 +1,9 @@
# {{ fullname | escape }}
```{automodule}
{{ fullname }}
:members:
:undoc-members:
:show-inheritance:
:imported-members:
```

8
doc/_templates/autosummary/module.rst vendored Normal file
View File

@@ -0,0 +1,8 @@
{{ fullname | escape | underline}}
.. automodule:: {{ fullname }}
:members:
:undoc-members:
:show-inheritance:
:imported-members:

49
doc/conf.py Normal file
View File

@@ -0,0 +1,49 @@
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Project information ------------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
project = "<package-name>"
copyright = "2025, Sam Griesemer"
author = "Sam Griesemer"
# -- General configuration ----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
extensions = [
"sphinx.ext.autodoc",
# enables a directive to be specified manually that gathers module/object
# summary details in a table
"sphinx.ext.autosummary",
# allow viewing source in the HTML pages
"sphinx.ext.viewcode",
# only really applies to manual docs; docstrings still need RST-like
"myst_parser",
# enables Google-style docstring formats
"sphinx.ext.napoleon",
# external extension that allows arg types to be inferred by type hints
"sphinx_autodoc_typehints",
]
autosummary_generate = True
autosummary_imported_members = True
# include __init__ definitions in autodoc
autodoc_default_options = {
"special-members": "__init__",
}
templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
# -- Options for HTML output --------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = "furo"
html_static_path = ["_static"]
# html_sidebars = {
# '**': ['/modules.html'],
# }

29
doc/index.md Normal file
View File

@@ -0,0 +1,29 @@
# `<project-name>` package docs
{ref}`genindex`
{ref}`modindex`
{ref}`search`
```{eval-rst}
.. autosummary::
:nosignatures:
# list modules here for quick links
```
```{toctree}
:maxdepth: 3
:caption: Autoref
_autoref/<project-name>.rst
```
```{toctree}
:maxdepth: 3
:caption: Contents
reference/documentation/index
reference/site/index
```
```{include} ../README.md
```

35
doc/make.bat Normal file
View File

@@ -0,0 +1,35 @@
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.https://www.sphinx-doc.org/
exit /b 1
)
if "%1" == "" goto help
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd

View File

@@ -0,0 +1,5 @@
# Documentation
```{toctree}
sphinx
```

View File

@@ -0,0 +1,111 @@
# Sphinx
The primary driver of this package's documentation is Sphinx's `autodoc` extension,
using the [Furo theme][1].
**High-level details**:
- `sphinx-apidoc` generates package-based documentation to the `_autoref/` directory,
with navigation available under "Autoref" in the sidebar.
- Markdown-based documentation files are manually written under the `reference/`
directory, showing up under "Contents" in the sidebar.
## Detailed directory structure
All files are placed under `docs/sphinx`:
- `_`-prefixed are Sphinx-managed directories
* `_build/html/` houses output HTML files
* `_autoref/` is the target for module-based RST files written by `autodoc`
- `reference/`: houses all manually written documentation (totally separate from
auto-generated package docs)
- `conf.py`: single Sphinx configuration file
- `index.md`: documentation index, setups up a persistent sidebar across all other pages
For manually written documentation under `reference/`, topics are nested as needed. Within
a nested directory `reference/<topic>`, an `index.md` should created with content like:
```
# <Topic>
\`\`\`{toctree}
:hidden:
sub-topic-1.rst
sub-topic-2.rst
...
\`\`\`
```
This will add the nested directory to the sidebar navigation, using the name set under the
top-level header. See [Markdown syntax][#markdown-syntax] for more details on the syntax.
## Sphinx autodoc
Sphinx's `autodoc` extension allows automatic generation of documents according to
(Python) subpackage structure and available docstrings. A few notes here:
- In the `conf.py` file, autodoc is enabled by adding `"sphinx.ext.autodoc"` to
the extensions list. `"sphinx.ext.viewcode"` can also be added to provide
links to source code.
- Documents are actually generated by calling the `sphinx-apidoc` CLI command. The
current Makefile uses the following call:
```sh
sphinx-apidoc --module-first -o docs/sphinx/_autoref/ localsys
```
This writes the automatically generated docs for modules in the package at the
local directory `localsys/` to the `docs/sphinx/_autoref` directory. These are
reStructuredText files by default.
* `--module-first` places the module-level descriptions at the top of the module page.
By default, this is placed at the bottom (oddly), and can be obscured by large lists
of subpackages if this flag isn't provided.
* See available `sphinx-apidoc` options [here][2], as well as more advanced config
[here][3].
## Markdown syntax
The `myst_parser` extension enables Markdown (or something close to it) to be used when
writing documentation files. The Sphinx directives can be difficult to track, and
they change slightly under the MyST Markdown syntax. The following are a few common
blocks:
**Page hierarchies**: the following will generate link hierarchy according to the provided
pages:
```
\`\`\`{toctree}
:maxdepth: <n>
:caption: <caption>
:hidden:
example-file-1
example-file-2
example-dir/index
...
\`\`\`
```
- `:maxdepth:` limits the depth of nesting
- `:caption:` title for the group of pages
- `:hidden:` if provided, links will only show in the sidebar (hidden on the page)
- Constituent files: listed files will be rendered as a link directly. If a listed file
has a `{toctree}` directive, this tree will be rendered in place of the page's link as a
dropdown. The dropdown will be named according to the file's top-level heading, and
clicking directly on the dropdown header will show that page's content. Files found in
the tree will be placed as links under the dropdown, recursively subject to same rules
described here.
**Include files**: the following will include file content
pages:
```
\`\`\`{include} README.md
\`\`\`
```
**Reference directives**
[1]: https://pradyunsg.me/furo/
[2]: https://www.sphinx-doc.org/en/master/man/sphinx-apidoc.html
[3]: https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#

20
example/dataset.py Normal file
View File

@@ -0,0 +1,20 @@
from typing import NamedTuple
from trainlib.domain import SequenceDomain
from trainlib.datasets.memory import TupleDataset
class Record(NamedTuple):
a: int
b: str
tl_domain = SequenceDomain[Record]([
Record(1, "1"),
Record(2, "2"),
])
class R0(TupleDataset[Record]):
item_tuple = Record
def _process_item_data(self, item_data, item_index):
return (item_data[0],)

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "trainlib"
version = "0.1.0"
version = "0.1.1"
description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training."
requires-python = ">=3.13"
authors = [

View File

@@ -1,5 +1,5 @@
"""
Marginalizing out the modality layer:
.. admonition:: Marginalizing out the modality layer
With ``domain`` being an instance variable, one possible interpretation of
the object structures here is that one could completely abstract away
@@ -58,64 +58,71 @@ Marginalizing out the modality layer:
particular case of ``_process_batch_data()``, it feels much better when
it's on the inside.)
Holding:
@abstractmethod
def _get_uri_groups(self) -> Iterable[tuple[U, ...]]:
Get URI groups for each batch.
.. admonition:: Holding area
If there's more than one URI per batch (e.g., a data file and a
metadata file), zip the URIs such that we have a tuple of URIs per
batch.
.. code-block:: python
Note that this effectively defines the index style over batches in the
attached domain. We get an ``int -> tuple[U, ...]`` map that turns
batch indices into URIs that can be read under the domain.
``get_batch()`` turns an integer index into its corresponding
``tuple[U, ...]``, reading the resources with ``_read_resources()`` in
the tuple, treating them as providers of batched data.
``_read_resources()`` passes through to the attached domain logic,
which, although common, need not supply an explicit iterable of batch
items: we just access items with ``__getitem__()`` and may ask for
``__len__``. So the returned URI group collection (this method) does
need to be iterable to measure the number of batches, but the batch
objects that are ultimately produced by these URI groups need not be
iterables themselves.
@abstractmethod
def _get_uri_groups(self) -> Iterable[tuple[U, ...]]:
Get URI groups for each batch.
raise NotImplementedError
If there's more than one URI per batch (e.g., a data file and a
metadata file), zip the URIs such that we have a tuple of URIs per
batch.
def _read_resources(
self,
uri_group: tuple[U, ...],
batch_index: int
) -> tuple[R, ...]:
Read batch files at the provided paths.
Note that this effectively defines the index style over batches in
the attached domain. We get an ``int -> tuple[U, ...]`` map that
turns batch indices into URIs that can be read under the domain.
``get_batch()`` turns an integer index into its corresponding
``tuple[U, ...]``, reading the resources with ``_read_resources()``
in the tuple, treating them as providers of batched data.
``_read_resources()`` passes through to the attached domain logic,
which, although common, need not supply an explicit iterable of
batch items: we just access items with ``__getitem__()`` and may
ask for ``__len__``. So the returned URI group collection (this
method) does need to be iterable to measure the number of batches,
but the batch objects that are ultimately produced by these URI
groups need not be iterables themselves.
This method should operate on a single tuple from the list of batch
tuples returned by the ``_get_uri_groups()`` method. That is, it reads
all of the resources for a single batch and returns a tuple of the same
size with their contents.
raise NotImplementedError
Note: the dependence on a batch index is mostly here to make
multi-dataset composition easier later. In-dataset, you don't need to
know the batch index to to simply process URIs, but across datasets you
need it to find out the origin of the batch (and process those URIs
accordingly).
def _read_resources(
self,
uri_group: tuple[U, ...],
batch_index: int
) -> tuple[R, ...]:
Read batch files at the provided paths.
return tuple(self.domain.read(uri) for uri in uri_group)
This method should operate on a single tuple from the list of batch
tuples returned by the ``_get_uri_groups()`` method. That is, it
reads all of the resources for a single batch and returns a tuple
of the same size with their contents.
# pulling the type variable out of the inline generic b/c `ty` has trouble
# understanding bound type variables in subclasses (specifically with Self@)
T = TypeVar("T", bound=NamedTuple)
Note: the dependence on a batch index is mostly here to make
multi-dataset composition easier later. In-dataset, you don't need
to know the batch index to to simply process URIs, but across
datasets you need it to find out the origin of the batch (and
process those URIs accordingly).
class NamedTupleDataset[I](Dataset):
def __init__(self, data_list: list[I]) -> None:
self.data_list = data_list
return tuple(self.domain.read(uri) for uri in uri_group)
def __len__(self) -> int:
return len(self.data_list)
.. code-block:: python
def __getitem__(self, index: int) -> I:
return self.data_list[index]
# pulling the type variable out of the inline generic b/c `ty` has
# trouble understanding bound type variables in subclasses
# (specifically with Self@)
T = TypeVar("T", bound=NamedTuple)
class NamedTupleDataset[I](Dataset):
def __init__(self, data_list: list[I]) -> None:
self.data_list = data_list
def __len__(self) -> int:
return len(self.data_list)
def __getitem__(self, index: int) -> I:
return self.data_list[index]
"""
import math
@@ -156,38 +163,41 @@ class BatchedDataset[U, R, I](Dataset):
which are used to concretize a domain ``Domain[U, R]``), and an item type
``T`` (which has a ``tuple`` upper bound).
Pipeline overview:
.. admonition:: Pipeline overview
```
Domain -> [U] (get _batch_uris)
U -> R (domain access ; Rs provide batches)
R -> [I] (cache here ; _process_batch_data to use load_transform)
[I] -> I (human item obj ; _get_item)
I -> **P (final packed item ; __getitem__ to use transform)
```
.. code-block:: python
Domain -> [U] (get _batch_uris)
U -> R (domain access ; Rs provide batches)
R -> [I] (cache here ; _process_batch_data to use load_transform)
[I] -> I (human item obj ; _get_item)
I -> **P (final packed item ; __getitem__ to use transform)
Note^1: as far as positioning, this class is meant to play nice with
PyTorch DataLoaders, hence the inheritance from ``torch.Dataset``. The
value add for this over the ``torch.Dataset`` base is almost entirely in
the logic it implements to map out of *batched resources* that are holding
data, and flattening it out into typical dataset items. There are also some
QoL items when it comes to splitting and balancing samples.
Note^1: as far as positioning, this class is meant to play nice with
PyTorch DataLoaders, hence the inheritance from ``torch.Dataset``. The
value add for this over the ``torch.Dataset`` base is almost entirely
in the logic it implements to map out of *batched resources* that are
holding data, and flattening it out into typical dataset items. There
are also some QoL items when it comes to splitting and balancing
samples.
Note^2: even though ``Domains`` implement iterators over their URIs, this
doesn't imply a ``BatchedDataset`` is iterable. This just means we can walk
over the resources that provide data, but we don't necessarily presuppose
an ordered walk over samples within batches. Point being:
``torch.Dataset``, not ``torch.IterableDataset``, is the appropriate
superclass, even when we're working around iterable ``Domains``.
Note^2: even though ``Domains`` implement iterators over their URIs,
this doesn't imply a ``BatchedDataset`` is iterable. This just means we
can walk over the resources that provide data, but we don't necessarily
presuppose an ordered walk over samples within batches. Point being:
``torch.Dataset``, not ``torch.IterableDataset``, is the appropriate
superclass, even when we're working around iterable ``Domains``.
Note^3: transforms are expected to operate on ``I``-items and produce
``I``-items. They shouldn't be the "introducers" of ``I`` types from some
other intermediate representation, nor should they map from ``I`` to
something else. Point being: the dataset definition should be able to map
resources ``R`` to ``I`` without a transform: that much should be baked
into the class definition. If you find you're expecting the transform to do
that for you, you should consider pulling in some common structure across
the allowed transforms and make it a fixed part of the class.
Note^3: transforms are expected to operate on ``I``-items and produce
``I``-items. They shouldn't be the "introducers" of ``I`` types from
some other intermediate representation, nor should they map from ``I``
to something else. Point being: the dataset definition should be able
to map resources ``R`` to ``I`` without a transform: that much should
be baked into the class definition. If you find you're expecting the
transform to do that for you, you should consider pulling in some
common structure across the allowed transforms and make it a fixed part
of the class.
"""
def __init__(

View File

@@ -24,7 +24,7 @@ from torch import nn, Tensor
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from trainlib.util.type import OptimizerKwargs
from trainlib.utils.type import OptimizerKwargs
logger: logging.Logger = logging.getLogger(__name__)

View File

@@ -1,10 +1,7 @@
import logging
from collections.abc import Generator
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.optim import Optimizer
from torch.nn.utils.parametrizations import weight_norm
logger: logging.Logger = logging.getLogger(__name__)

View File

@@ -11,6 +11,7 @@ from collections.abc import Callable
import torch
from tqdm import tqdm
from torch import cuda, Tensor
from torch.optim import Optimizer
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
@@ -98,6 +99,151 @@ class Trainer[I, K: EstimatorKwargs]:
self._stagnant_epochs = 0
self._best_model_state_dict: dict[str, Any] = {}
def _train_epoch(
self,
train_loader: DataLoader,
batch_estimator_map: Callable[[I, Self], K],
optimizers: tuple[Optimizer, ...],
writer: SummaryWriter,
max_grad_norm: float | None = None,
) -> list[float]:
"""
Train the estimator for a single epoch.
"""
train_loss_sums = []
self.estimator.train()
with tqdm(train_loader, unit="batch") as train_epoch:
for i, batch_data in enumerate(train_epoch):
est_kwargs = batch_estimator_map(batch_data, self)
inputs = est_kwargs["inputs"]
# one-time logging
if self._step == 0:
writer.add_graph(
ModelWrapper(self.estimator),
est_kwargs
)
# once-per-epoch logging
if i == 0:
self.estimator.epoch_write(
writer,
step=self._step,
val=False,
**est_kwargs
)
train_losses = self.estimator.loss(**est_kwargs)
train_loss_items = []
for o_idx, optimizer in enumerate(optimizers):
optimizer.zero_grad()
train_loss = next(train_losses)
if len(train_loss_sums) <= o_idx:
train_loss_sums.append(0.0)
train_loss_item = train_loss.item()
train_loss_sums[o_idx] += train_loss_item
train_loss_items.append(train_loss_item)
train_loss.backward()
# clip gradients for optimizer's parameters
if max_grad_norm is not None:
clip_grad_norm_(
self._get_optimizer_parameters(optimizer),
max_norm=max_grad_norm
)
optimizer.step()
self._step += len(inputs)
for train_loss_item, train_loss_sum in zip(
train_loss_items,
train_loss_sums,
strict=True,
):
train_epoch.set_postfix(loss=f"{train_loss_sum/(i+1):8.2f}")
self._add_summary_item("train_loss", train_loss_item)
estimator_metrics = self.estimator.metrics(**est_kwargs)
for metric_name, metric_value in estimator_metrics.items():
self._add_summary_item(
f"train_{metric_name}",
metric_value
)
self.estimator.epoch_step()
for li, train_loss_sum in enumerate(train_loss_sums):
self._add_summary_item(
f"train_loss{li}_epoch", train_loss_sum / len(train_loader)
)
return train_loss_sums
def _val_epoch(
self,
val_loader: DataLoader,
batch_estimator_map: Callable[[I, Self], K],
optimizers: tuple[Optimizer, ...],
writer: SummaryWriter,
) -> list[float]:
"""
Perform and record validation scores for a single epoch.
"""
val_loss_sums = []
self.estimator.eval()
with tqdm(val_loader, unit="batch") as val_epoch:
for i, batch_data in enumerate(val_epoch):
est_kwargs = batch_estimator_map(batch_data, self)
# once-per-epoch logging
if i == 0:
self.estimator.epoch_write(
writer,
step=self._step,
val=True,
**est_kwargs
)
val_losses = self.estimator.loss(**est_kwargs)
val_loss_items = []
for o_idx in range(len(optimizers)):
val_loss = next(val_losses)
if len(val_loss_sums) <= o_idx:
val_loss_sums.append(0.0)
val_loss_item = val_loss.item()
val_loss_sums[o_idx] += val_loss_item
val_loss_items.append(val_loss_item)
for val_loss_item, val_loss_sum in zip(
val_loss_items,
val_loss_sums,
strict=True,
):
val_epoch.set_postfix(loss=f"{val_loss_sum/(i+1):8.2f}")
self._add_summary_item("val_loss", val_loss_item)
estimator_metrics = self.estimator.metrics(**est_kwargs)
for metric_name, metric_value in estimator_metrics.items():
self._add_summary_item(f"val_{metric_name}", metric_value)
for li, val_loss_sum in enumerate(val_loss_sums):
self._add_summary_item(
f"val_loss{li}_epoch", val_loss_sum / len(val_loader)
)
# convergence of multiple losses may be ambiguous
self._val_loss = sum(val_loss_sums) / len(val_loader)
return val_loss_sums
def train(
self,
dataset: BatchedDataset[..., ..., I],
@@ -122,40 +268,45 @@ class Trainer[I, K: EstimatorKwargs]:
"""
Note: this method attempts to implement a general scheme for passing
needed items to the estimator's loss function from the dataloader. The
abstract `Estimator` base only requires the model output be provided
abstract ``Estimator`` base only requires the model output be provided
for any given loss calculation, but concrete estimators will often
require additional arguments (e.g., labels or length masks, as
is the case with sequential models). In any case, this method defers
any further logic to the `loss` method of the underlying estimator, so
require additional arguments (e.g., labels or length masks, as is the
case with sequential models). In any case, this method defers any
further logic to the ``loss`` method of the underlying estimator, so
one should take care to synchronize the sample structure with `dataset`
to match that expected by `self.estimator.loss(...)`.
to match that expected by ``self.estimator.loss(...)``.
.. admonition:: On batch_estimator_map
On batch_estimator_map:
Dataloader collate functions are responsible for mapping a
collection of items into an item of collections, roughly speaking.
If items are tuples of tensors,
Dataloader collate functions are responsible for mapping a collection
of items into an item of collections, roughly speaking. If items are
tuples of tensors,
.. code-block::
[
( [1, 1], [1, 1] ),
( [2, 2], [2, 2] ),
( [3, 3], [3, 3] ),
]
[
( [1, 1], [1, 1] ),
( [2, 2], [2, 2] ),
( [3, 3], [3, 3] ),
]
the collate function maps back into the item skeleton, producing a
single tuple of (stacked) tensors
the collate function maps back into the item skeleton, producing a
single tuple of (stacked) tensors
.. code-block::
( [[1, 1],
[2, 2],
[3, 3]],
( [[1, 1],
[2, 2],
[3, 3]],
[[1, 1],
[2, 2],
[3, 3]] )
[[1, 1],
[2, 2],
[3, 3]] )
This function should map from batches (which should be *item shaped*,
i.e., have an `I` skeleton, even if stacked items may be different on
the inside) into estimator keyword arguments (type `K`).
This function should map from batches (which should be *item
shaped*, i.e., have an ``I`` skeleton, even if stacked items may be
different on the inside) into estimator keyword arguments (type
``K``).
Parameters:
lr: learning rate (default: 1e-3)
@@ -212,123 +363,32 @@ class Trainer[I, K: EstimatorKwargs]:
while self._epoch <= max_epochs and not self._converged(
self._epoch, stop_after_epochs
):
print(f"Training epoch {self._epoch}/{max_epochs}...")
print(f"Stagnant epochs {self._stagnant_epochs}/{stop_after_epochs}...")
train_frac = f"{self._epoch}/{max_epochs}"
stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}"
print(f"Training epoch {train_frac}...")
print(f"Stagnant epochs {stag_frac}...")
epoch_start_time = time.time()
train_loss_sums = []
self.estimator.train()
with tqdm(train_loader, unit="batch") as train_epoch:
for i, batch_data in enumerate(train_epoch):
est_kwargs = batch_estimator_map(batch_data, self)
inputs = est_kwargs["inputs"]
# one-time logging
if self._step == 0:
writer.add_graph(ModelWrapper(self.estimator), est_kwargs)
# once-per-epoch logging
if i == 0:
self.estimator.epoch_write(
writer,
step=self._step,
val=False,
**est_kwargs
)
train_losses = self.estimator.loss(**est_kwargs)
train_loss_items = []
for o_idx, optimizer in enumerate(optimizers):
optimizer.zero_grad()
train_loss = next(train_losses)
if len(train_loss_sums) <= o_idx:
train_loss_sums.append(0.0)
train_loss_item = train_loss.item()
train_loss_sums[o_idx] += train_loss_item
train_loss_items.append(train_loss_item)
train_loss.backward()
# clip gradients for optimizer's parameters
if max_grad_norm is not None:
opt_params = self._get_optimizer_parameters(optimizer)
clip_grad_norm_(opt_params, max_norm=max_grad_norm)
optimizer.step()
self._step += len(inputs)
for train_loss_item, train_loss_sum in zip(
train_loss_items,
train_loss_sums,
strict=True,
):
train_epoch.set_postfix(loss=f"{train_loss_sum/(i+1):8.2f}")
self._add_summary_item("train_loss", train_loss_item)
estimator_metrics = self.estimator.metrics(**est_kwargs)
for metric_name, metric_value in estimator_metrics.items():
self._add_summary_item(f"train_{metric_name}", metric_value)
self.estimator.epoch_step()
for li, train_loss_sum in enumerate(train_loss_sums):
self._add_summary_item(
f"train_loss{li}_epoch", train_loss_sum / len(train_loader)
)
self._train_epoch(
train_loader,
batch_estimator_map,
optimizers,
writer,
max_grad_norm
)
if val_frac > 0:
val_loss_sums = []
self.estimator.eval()
with tqdm(val_loader, unit="batch") as val_epoch:
for i, batch_data in enumerate(val_epoch):
est_kwargs = batch_estimator_map(batch_data, self)
inputs = est_kwargs["inputs"]
self._val_epoch(
val_loader,
batch_estimator_map,
optimizers,
writer,
)
# once-per-epoch logging
if i == 0:
self.estimator.epoch_write(
writer,
step=self._step,
val=True,
**est_kwargs
)
val_losses = self.estimator.loss(**est_kwargs)
val_loss_items = []
for o_idx in range(len(optimizers)):
val_loss = next(val_losses)
if len(val_loss_sums) <= o_idx:
val_loss_sums.append(0.0)
val_loss_item = val_loss.item()
val_loss_sums[o_idx] += val_loss_item
val_loss_items.append(val_loss_item)
for val_loss_item, val_loss_sum in zip(
val_loss_items,
val_loss_sums,
strict=True,
):
val_epoch.set_postfix(loss=f"{val_loss_sum/(i+1):8.2f}")
self._add_summary_item("val_loss", val_loss_item)
estimator_metrics = self.estimator.metrics(**est_kwargs)
for metric_name, metric_value in estimator_metrics.items():
self._add_summary_item(f"val_{metric_name}", metric_value)
for li, val_loss_sum in enumerate(val_loss_sums):
self._add_summary_item(
f"val_loss{li}_epoch", val_loss_sum / len(val_loader)
)
# convergence of multiple losses may be ambiguous
self._val_loss = sum(val_loss_sums) / len(val_loader)
self._add_summary_item("epoch_time_sec", time.time() - epoch_start_time)
self._add_summary_item(
"epoch_time_sec",
time.time() - epoch_start_time
)
if self._epoch % summarize_every == 0:
self._summarize(writer, self._epoch)
@@ -336,7 +396,9 @@ class Trainer[I, K: EstimatorKwargs]:
# save checkpoint
if self._epoch % chkpt_every == 0:
self.save_model(
self._epoch, self.chkpt_dir, dir_prefix
self._epoch,
self.chkpt_dir,
dir_prefix
)
self._epoch += 1