Compare commits
2 Commits
c473e48b5b
...
faeef9c72a
| Author | SHA1 | Date | |
|---|---|---|---|
| faeef9c72a | |||
| 805262dfc4 |
16
README.md
16
README.md
@@ -2,16 +2,10 @@
|
||||
Package summary goes here, ideally with a diagram
|
||||
|
||||
# Install
|
||||
Installation instructions
|
||||
The `trainlib` package can be installed from PyPI:
|
||||
|
||||
```sh
|
||||
pip install <package>
|
||||
```
|
||||
|
||||
or as a CLI tool
|
||||
|
||||
```sh
|
||||
uv tool install <package>
|
||||
pip install trainlib
|
||||
```
|
||||
|
||||
# Development
|
||||
@@ -20,16 +14,16 @@ uv tool install <package>
|
||||
- Depending on needs, install the development dependencies with `uv sync
|
||||
--extra dev`.
|
||||
|
||||
# Testing
|
||||
## Testing
|
||||
- To run the unit tests, make sure to first have the test dependencies
|
||||
installed with `uv sync --extra test`, then run `make test`.
|
||||
- For notebook testing, run `make install-kernel` to make the environment
|
||||
available as a Jupyter kernel (to be selected when running notebooks).
|
||||
|
||||
# Documentation
|
||||
## Documentation
|
||||
- Install the documentation dependencies with `uv sync --extra doc`.
|
||||
- Run `make docs-build` (optionally preceded by `make docs-clean`), and serve
|
||||
locally with `docs-serve`.
|
||||
locally with `make docs-serve`.
|
||||
|
||||
# Development remarks
|
||||
- Across `Trainer` / `Estimator` / `Dataset`, I've considered a
|
||||
|
||||
20
doc/Makefile
Normal file
20
doc/Makefile
Normal file
@@ -0,0 +1,20 @@
|
||||
# Minimal makefile for Sphinx documentation
|
||||
#
|
||||
|
||||
# You can set these variables from the command line, and also
|
||||
# from the environment for the first two.
|
||||
SPHINXOPTS ?=
|
||||
SPHINXBUILD ?= sphinx-build
|
||||
SOURCEDIR = .
|
||||
BUILDDIR = _build
|
||||
|
||||
# Put it first so that "make" without argument is like "make help".
|
||||
help:
|
||||
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
|
||||
.PHONY: help Makefile
|
||||
|
||||
# Catch-all target: route all unknown targets to Sphinx using the new
|
||||
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
||||
%: Makefile
|
||||
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
9
doc/_templates/autosummary.md
vendored
Normal file
9
doc/_templates/autosummary.md
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
# {{ fullname | escape }}
|
||||
|
||||
```{automodule}
|
||||
{{ fullname }}
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
:imported-members:
|
||||
```
|
||||
8
doc/_templates/autosummary/module.rst
vendored
Normal file
8
doc/_templates/autosummary/module.rst
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
{{ fullname | escape | underline}}
|
||||
|
||||
.. automodule:: {{ fullname }}
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
:imported-members:
|
||||
|
||||
49
doc/conf.py
Normal file
49
doc/conf.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# Configuration file for the Sphinx documentation builder.
|
||||
#
|
||||
# For the full list of built-in configuration values, see the documentation:
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
||||
|
||||
# -- Project information ------------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
|
||||
|
||||
project = "<package-name>"
|
||||
copyright = "2025, Sam Griesemer"
|
||||
author = "Sam Griesemer"
|
||||
|
||||
# -- General configuration ----------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
||||
|
||||
extensions = [
|
||||
"sphinx.ext.autodoc",
|
||||
# enables a directive to be specified manually that gathers module/object
|
||||
# summary details in a table
|
||||
"sphinx.ext.autosummary",
|
||||
# allow viewing source in the HTML pages
|
||||
"sphinx.ext.viewcode",
|
||||
# only really applies to manual docs; docstrings still need RST-like syntax
|
||||
"myst_parser",
|
||||
# enables Google-style docstring formats
|
||||
"sphinx.ext.napoleon",
|
||||
# external extension that allows arg types to be inferred by type hints
|
||||
"sphinx_autodoc_typehints",
|
||||
]
|
||||
autosummary_generate = True
|
||||
autosummary_imported_members = True
|
||||
|
||||
# include __init__ definitions in autodoc
|
||||
autodoc_default_options = {
|
||||
"special-members": "__init__",
|
||||
}
|
||||
|
||||
templates_path = ["_templates"]
|
||||
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
||||
|
||||
|
||||
# -- Options for HTML output --------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
|
||||
|
||||
html_theme = "furo"
|
||||
html_static_path = ["_static"]
|
||||
# html_sidebars = {
|
||||
# '**': ['/modules.html'],
|
||||
# }
|
||||
29
doc/index.md
Normal file
29
doc/index.md
Normal file
@@ -0,0 +1,29 @@
|
||||
# `trainlib` package docs
|
||||
{ref}`genindex`
|
||||
{ref}`modindex`
|
||||
{ref}`search`
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:nosignatures:
|
||||
|
||||
# list modules here for quick links
|
||||
```
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 3
|
||||
:caption: Autoref
|
||||
|
||||
_autoref/trainlib.rst
|
||||
```
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 3
|
||||
:caption: Contents
|
||||
|
||||
reference/documentation/index
|
||||
reference/site/index
|
||||
```
|
||||
|
||||
```{include} ../README.md
|
||||
```
|
||||
35
doc/make.bat
Normal file
35
doc/make.bat
Normal file
@@ -0,0 +1,35 @@
|
||||
@ECHO OFF
|
||||
|
||||
pushd %~dp0
|
||||
|
||||
REM Command file for Sphinx documentation
|
||||
|
||||
if "%SPHINXBUILD%" == "" (
|
||||
set SPHINXBUILD=sphinx-build
|
||||
)
|
||||
set SOURCEDIR=.
|
||||
set BUILDDIR=_build
|
||||
|
||||
%SPHINXBUILD% >NUL 2>NUL
|
||||
if errorlevel 9009 (
|
||||
echo.
|
||||
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
||||
echo.installed, then set the SPHINXBUILD environment variable to point
|
||||
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
||||
echo.may add the Sphinx directory to PATH.
|
||||
echo.
|
||||
echo.If you don't have Sphinx installed, grab it from
|
||||
echo.https://www.sphinx-doc.org/
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
if "%1" == "" goto help
|
||||
|
||||
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
goto end
|
||||
|
||||
:help
|
||||
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
|
||||
:end
|
||||
popd
|
||||
5
doc/reference/documentation/index.md
Normal file
5
doc/reference/documentation/index.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# Documentation
|
||||
|
||||
```{toctree}
|
||||
sphinx
|
||||
```
|
||||
111
doc/reference/documentation/sphinx.md
Normal file
111
doc/reference/documentation/sphinx.md
Normal file
@@ -0,0 +1,111 @@
|
||||
# Sphinx
|
||||
The primary driver of this package's documentation is Sphinx's `autodoc` extension,
|
||||
using the [Furo theme][1].
|
||||
|
||||
**High-level details**:
|
||||
|
||||
- `sphinx-apidoc` generates package-based documentation to the `_autoref/` directory,
|
||||
with navigation available under "Autoref" in the sidebar.
|
||||
- Markdown-based documentation files are manually written under the `reference/`
|
||||
directory, showing up under "Contents" in the sidebar.
|
||||
|
||||
## Detailed directory structure
|
||||
All files are placed under `docs/sphinx`:
|
||||
|
||||
- `_`-prefixed are Sphinx-managed directories
|
||||
* `_build/html/` houses output HTML files
|
||||
* `_autoref/` is the target for module-based RST files written by `sphinx-apidoc`
|
||||
- `reference/`: houses all manually written documentation (totally separate from
|
||||
auto-generated package docs)
|
||||
- `conf.py`: single Sphinx configuration file
|
||||
- `index.md`: documentation index, sets up a persistent sidebar across all other pages
|
||||
|
||||
For manually written documentation under `reference/`, topics are nested as needed. Within
|
||||
a nested directory `reference/<topic>`, an `index.md` should be created with content like:
|
||||
|
||||
```
|
||||
# <Topic>
|
||||
|
||||
\`\`\`{toctree}
|
||||
:hidden:
|
||||
|
||||
sub-topic-1.rst
|
||||
sub-topic-2.rst
|
||||
...
|
||||
\`\`\`
|
||||
```
|
||||
|
||||
This will add the nested directory to the sidebar navigation, using the name set under the
|
||||
top-level header. See [Markdown syntax](#markdown-syntax) for more details on the syntax.
|
||||
|
||||
## Sphinx autodoc
|
||||
Sphinx's `autodoc` extension allows automatic generation of documents according to
|
||||
(Python) subpackage structure and available docstrings. A few notes here:
|
||||
|
||||
- In the `conf.py` file, autodoc is enabled by adding `"sphinx.ext.autodoc"` to
|
||||
the extensions list. `"sphinx.ext.viewcode"` can also be added to provide
|
||||
links to source code.
|
||||
- Documents are actually generated by calling the `sphinx-apidoc` CLI command. The
|
||||
current Makefile uses the following call:
|
||||
|
||||
```sh
|
||||
sphinx-apidoc --module-first -o docs/sphinx/_autoref/ localsys
|
||||
```
|
||||
|
||||
This writes the automatically generated docs for modules in the package at the
|
||||
local directory `localsys/` to the `docs/sphinx/_autoref` directory. These are
|
||||
reStructuredText files by default.
|
||||
* `--module-first` places the module-level descriptions at the top of the module page.
|
||||
By default, this is placed at the bottom (oddly), and can be obscured by large lists
|
||||
of subpackages if this flag isn't provided.
|
||||
* See available `sphinx-apidoc` options [here][2], as well as more advanced config
|
||||
[here][3].
|
||||
|
||||
|
||||
## Markdown syntax
|
||||
The `myst_parser` extension enables Markdown (or something close to it) to be used when
|
||||
writing documentation files. The Sphinx directives can be difficult to track, and
|
||||
they change slightly under the MyST Markdown syntax. The following are a few common
|
||||
blocks:
|
||||
|
||||
**Page hierarchies**: the following will generate a link hierarchy according to the provided
|
||||
pages:
|
||||
|
||||
```
|
||||
\`\`\`{toctree}
|
||||
:maxdepth: <n>
|
||||
:caption: <caption>
|
||||
:hidden:
|
||||
|
||||
example-file-1
|
||||
example-file-2
|
||||
example-dir/index
|
||||
...
|
||||
\`\`\`
|
||||
```
|
||||
|
||||
- `:maxdepth:` limits the depth of nesting
|
||||
- `:caption:` title for the group of pages
|
||||
- `:hidden:` if provided, links will only show in the sidebar (hidden on the page)
|
||||
- Constituent files: listed files will be rendered as a link directly. If a listed file
|
||||
has a `{toctree}` directive, this tree will be rendered in place of the page's link as a
|
||||
dropdown. The dropdown will be named according to the file's top-level heading, and
|
||||
clicking directly on the dropdown header will show that page's content. Files found in
|
||||
the tree will be placed as links under the dropdown, recursively subject to the same rules
|
||||
described here.
|
||||
|
||||
**Include files**: the following will include file content
|
||||
in the page:
|
||||
|
||||
```
|
||||
\`\`\`{include} README.md
|
||||
\`\`\`
|
||||
```
|
||||
|
||||
**Reference directives**
|
||||
|
||||
|
||||
[1]: https://pradyunsg.me/furo/
|
||||
[2]: https://www.sphinx-doc.org/en/master/man/sphinx-apidoc.html
|
||||
[3]: https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#
|
||||
|
||||
20
example/dataset.py
Normal file
20
example/dataset.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import NamedTuple
|
||||
|
||||
from trainlib.domain import SequenceDomain
|
||||
from trainlib.datasets.memory import TupleDataset
|
||||
|
||||
|
||||
class Record(NamedTuple):
|
||||
a: int
|
||||
b: str
|
||||
|
||||
tl_domain = SequenceDomain[Record]([
|
||||
Record(1, "1"),
|
||||
Record(2, "2"),
|
||||
])
|
||||
|
||||
class R0(TupleDataset[Record]):
|
||||
item_tuple = Record
|
||||
|
||||
def _process_item_data(self, item_data, item_index):
|
||||
return (item_data[0],)
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "trainlib"
|
||||
version = "0.1.0"
|
||||
version = "0.1.1"
|
||||
description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training."
|
||||
requires-python = ">=3.13"
|
||||
authors = [
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Marginalizing out the modality layer:
|
||||
.. admonition:: Marginalizing out the modality layer
|
||||
|
||||
With ``domain`` being an instance variable, one possible interpretation of
|
||||
the object structures here is that one could completely abstract away
|
||||
@@ -58,7 +58,10 @@ Marginalizing out the modality layer:
|
||||
particular case of ``_process_batch_data()``, it feels much better when
|
||||
it's on the inside.)
|
||||
|
||||
Holding:
|
||||
.. admonition:: Holding area
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@abstractmethod
|
||||
def _get_uri_groups(self) -> Iterable[tuple[U, ...]]:
|
||||
Get URI groups for each batch.
|
||||
@@ -67,19 +70,19 @@ Holding:
|
||||
metadata file), zip the URIs such that we have a tuple of URIs per
|
||||
batch.
|
||||
|
||||
Note that this effectively defines the index style over batches in the
|
||||
attached domain. We get an ``int -> tuple[U, ...]`` map that turns
|
||||
batch indices into URIs that can be read under the domain.
|
||||
Note that this effectively defines the index style over batches in
|
||||
the attached domain. We get an ``int -> tuple[U, ...]`` map that
|
||||
turns batch indices into URIs that can be read under the domain.
|
||||
``get_batch()`` turns an integer index into its corresponding
|
||||
``tuple[U, ...]``, reading the resources with ``_read_resources()`` in
|
||||
the tuple, treating them as providers of batched data.
|
||||
``tuple[U, ...]``, reading the resources with ``_read_resources()``
|
||||
in the tuple, treating them as providers of batched data.
|
||||
``_read_resources()`` passes through to the attached domain logic,
|
||||
which, although common, need not supply an explicit iterable of batch
|
||||
items: we just access items with ``__getitem__()`` and may ask for
|
||||
``__len__``. So the returned URI group collection (this method) does
|
||||
need to be iterable to measure the number of batches, but the batch
|
||||
objects that are ultimately produced by these URI groups need not be
|
||||
iterables themselves.
|
||||
which, although common, need not supply an explicit iterable of
|
||||
batch items: we just access items with ``__getitem__()`` and may
|
||||
ask for ``__len__``. So the returned URI group collection (this
|
||||
method) does need to be iterable to measure the number of batches,
|
||||
but the batch objects that are ultimately produced by these URI
|
||||
groups need not be iterables themselves.
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -91,20 +94,24 @@ Holding:
|
||||
Read batch files at the provided paths.
|
||||
|
||||
This method should operate on a single tuple from the list of batch
|
||||
tuples returned by the ``_get_uri_groups()`` method. That is, it reads
|
||||
all of the resources for a single batch and returns a tuple of the same
|
||||
size with their contents.
|
||||
tuples returned by the ``_get_uri_groups()`` method. That is, it
|
||||
reads all of the resources for a single batch and returns a tuple
|
||||
of the same size with their contents.
|
||||
|
||||
Note: the dependence on a batch index is mostly here to make
|
||||
multi-dataset composition easier later. In-dataset, you don't need to
|
||||
know the batch index to simply process URIs, but across datasets you
|
||||
need it to find out the origin of the batch (and process those URIs
|
||||
accordingly).
|
||||
multi-dataset composition easier later. In-dataset, you don't need
|
||||
to know the batch index to simply process URIs, but across
|
||||
datasets you need it to find out the origin of the batch (and
|
||||
process those URIs accordingly).
|
||||
|
||||
return tuple(self.domain.read(uri) for uri in uri_group)
|
||||
|
||||
# pulling the type variable out of the inline generic b/c `ty` has trouble
|
||||
# understanding bound type variables in subclasses (specifically with Self@)
|
||||
.. code-block:: python
|
||||
|
||||
# pulling the type variable out of the inline generic b/c `ty` has
|
||||
# trouble understanding bound type variables in subclasses
|
||||
# (specifically with Self@)
|
||||
|
||||
T = TypeVar("T", bound=NamedTuple)
|
||||
|
||||
class NamedTupleDataset[I](Dataset):
|
||||
@@ -156,38 +163,41 @@ class BatchedDataset[U, R, I](Dataset):
|
||||
which are used to concretize a domain ``Domain[U, R]``), and an item type
|
||||
``T`` (which has a ``tuple`` upper bound).
|
||||
|
||||
Pipeline overview:
|
||||
|
||||
```
|
||||
.. admonition:: Pipeline overview
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
Domain -> [U] (get _batch_uris)
|
||||
U -> R (domain access ; Rs provide batches)
|
||||
R -> [I] (cache here ; _process_batch_data to use load_transform)
|
||||
[I] -> I (human item obj ; _get_item)
|
||||
I -> **P (final packed item ; __getitem__ to use transform)
|
||||
```
|
||||
|
||||
Note^1: as far as positioning, this class is meant to play nice with
|
||||
PyTorch DataLoaders, hence the inheritance from ``torch.Dataset``. The
|
||||
value add for this over the ``torch.Dataset`` base is almost entirely in
|
||||
the logic it implements to map out of *batched resources* that are holding
|
||||
data, and flattening it out into typical dataset items. There are also some
|
||||
QoL items when it comes to splitting and balancing samples.
|
||||
value add for this over the ``torch.Dataset`` base is almost entirely
|
||||
in the logic it implements to map out of *batched resources* that are
|
||||
holding data, and flattening it out into typical dataset items. There
|
||||
are also some QoL items when it comes to splitting and balancing
|
||||
samples.
|
||||
|
||||
Note^2: even though ``Domains`` implement iterators over their URIs, this
|
||||
doesn't imply a ``BatchedDataset`` is iterable. This just means we can walk
|
||||
over the resources that provide data, but we don't necessarily presuppose
|
||||
an ordered walk over samples within batches. Point being:
|
||||
Note^2: even though ``Domains`` implement iterators over their URIs,
|
||||
this doesn't imply a ``BatchedDataset`` is iterable. This just means we
|
||||
can walk over the resources that provide data, but we don't necessarily
|
||||
presuppose an ordered walk over samples within batches. Point being:
|
||||
``torch.Dataset``, not ``torch.IterableDataset``, is the appropriate
|
||||
superclass, even when we're working around iterable ``Domains``.
|
||||
|
||||
Note^3: transforms are expected to operate on ``I``-items and produce
|
||||
``I``-items. They shouldn't be the "introducers" of ``I`` types from some
|
||||
other intermediate representation, nor should they map from ``I`` to
|
||||
something else. Point being: the dataset definition should be able to map
|
||||
resources ``R`` to ``I`` without a transform: that much should be baked
|
||||
into the class definition. If you find you're expecting the transform to do
|
||||
that for you, you should consider pulling in some common structure across
|
||||
the allowed transforms and make it a fixed part of the class.
|
||||
``I``-items. They shouldn't be the "introducers" of ``I`` types from
|
||||
some other intermediate representation, nor should they map from ``I``
|
||||
to something else. Point being: the dataset definition should be able
|
||||
to map resources ``R`` to ``I`` without a transform: that much should
|
||||
be baked into the class definition. If you find you're expecting the
|
||||
transform to do that for you, you should consider pulling in some
|
||||
common structure across the allowed transforms and make it a fixed part
|
||||
of the class.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -24,7 +24,7 @@ from torch import nn, Tensor
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from trainlib.util.type import OptimizerKwargs
|
||||
from trainlib.utils.type import OptimizerKwargs
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, Tensor
|
||||
from torch.optim import Optimizer
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -11,6 +11,7 @@ from collections.abc import Callable
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from torch import cuda, Tensor
|
||||
from torch.optim import Optimizer
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
@@ -98,6 +99,151 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
self._stagnant_epochs = 0
|
||||
self._best_model_state_dict: dict[str, Any] = {}
|
||||
|
||||
def _train_epoch(
|
||||
self,
|
||||
train_loader: DataLoader,
|
||||
batch_estimator_map: Callable[[I, Self], K],
|
||||
optimizers: tuple[Optimizer, ...],
|
||||
writer: SummaryWriter,
|
||||
max_grad_norm: float | None = None,
|
||||
) -> list[float]:
|
||||
"""
|
||||
Train the estimator for a single epoch.
|
||||
"""
|
||||
|
||||
train_loss_sums = []
|
||||
self.estimator.train()
|
||||
with tqdm(train_loader, unit="batch") as train_epoch:
|
||||
for i, batch_data in enumerate(train_epoch):
|
||||
est_kwargs = batch_estimator_map(batch_data, self)
|
||||
inputs = est_kwargs["inputs"]
|
||||
|
||||
# one-time logging
|
||||
if self._step == 0:
|
||||
writer.add_graph(
|
||||
ModelWrapper(self.estimator),
|
||||
est_kwargs
|
||||
)
|
||||
|
||||
# once-per-epoch logging
|
||||
if i == 0:
|
||||
self.estimator.epoch_write(
|
||||
writer,
|
||||
step=self._step,
|
||||
val=False,
|
||||
**est_kwargs
|
||||
)
|
||||
|
||||
train_losses = self.estimator.loss(**est_kwargs)
|
||||
train_loss_items = []
|
||||
for o_idx, optimizer in enumerate(optimizers):
|
||||
optimizer.zero_grad()
|
||||
train_loss = next(train_losses)
|
||||
|
||||
if len(train_loss_sums) <= o_idx:
|
||||
train_loss_sums.append(0.0)
|
||||
|
||||
train_loss_item = train_loss.item()
|
||||
train_loss_sums[o_idx] += train_loss_item
|
||||
train_loss_items.append(train_loss_item)
|
||||
|
||||
train_loss.backward()
|
||||
|
||||
# clip gradients for optimizer's parameters
|
||||
if max_grad_norm is not None:
|
||||
clip_grad_norm_(
|
||||
self._get_optimizer_parameters(optimizer),
|
||||
max_norm=max_grad_norm
|
||||
)
|
||||
|
||||
optimizer.step()
|
||||
|
||||
self._step += len(inputs)
|
||||
|
||||
for train_loss_item, train_loss_sum in zip(
|
||||
train_loss_items,
|
||||
train_loss_sums,
|
||||
strict=True,
|
||||
):
|
||||
train_epoch.set_postfix(loss=f"{train_loss_sum/(i+1):8.2f}")
|
||||
self._add_summary_item("train_loss", train_loss_item)
|
||||
|
||||
estimator_metrics = self.estimator.metrics(**est_kwargs)
|
||||
for metric_name, metric_value in estimator_metrics.items():
|
||||
self._add_summary_item(
|
||||
f"train_{metric_name}",
|
||||
metric_value
|
||||
)
|
||||
|
||||
self.estimator.epoch_step()
|
||||
|
||||
for li, train_loss_sum in enumerate(train_loss_sums):
|
||||
self._add_summary_item(
|
||||
f"train_loss{li}_epoch", train_loss_sum / len(train_loader)
|
||||
)
|
||||
|
||||
return train_loss_sums
|
||||
|
||||
def _val_epoch(
|
||||
self,
|
||||
val_loader: DataLoader,
|
||||
batch_estimator_map: Callable[[I, Self], K],
|
||||
optimizers: tuple[Optimizer, ...],
|
||||
writer: SummaryWriter,
|
||||
) -> list[float]:
|
||||
"""
|
||||
Perform and record validation scores for a single epoch.
|
||||
"""
|
||||
|
||||
val_loss_sums = []
|
||||
self.estimator.eval()
|
||||
with tqdm(val_loader, unit="batch") as val_epoch:
|
||||
for i, batch_data in enumerate(val_epoch):
|
||||
est_kwargs = batch_estimator_map(batch_data, self)
|
||||
|
||||
# once-per-epoch logging
|
||||
if i == 0:
|
||||
self.estimator.epoch_write(
|
||||
writer,
|
||||
step=self._step,
|
||||
val=True,
|
||||
**est_kwargs
|
||||
)
|
||||
|
||||
val_losses = self.estimator.loss(**est_kwargs)
|
||||
val_loss_items = []
|
||||
for o_idx in range(len(optimizers)):
|
||||
val_loss = next(val_losses)
|
||||
|
||||
if len(val_loss_sums) <= o_idx:
|
||||
val_loss_sums.append(0.0)
|
||||
|
||||
val_loss_item = val_loss.item()
|
||||
val_loss_sums[o_idx] += val_loss_item
|
||||
val_loss_items.append(val_loss_item)
|
||||
|
||||
for val_loss_item, val_loss_sum in zip(
|
||||
val_loss_items,
|
||||
val_loss_sums,
|
||||
strict=True,
|
||||
):
|
||||
val_epoch.set_postfix(loss=f"{val_loss_sum/(i+1):8.2f}")
|
||||
self._add_summary_item("val_loss", val_loss_item)
|
||||
|
||||
estimator_metrics = self.estimator.metrics(**est_kwargs)
|
||||
for metric_name, metric_value in estimator_metrics.items():
|
||||
self._add_summary_item(f"val_{metric_name}", metric_value)
|
||||
|
||||
for li, val_loss_sum in enumerate(val_loss_sums):
|
||||
self._add_summary_item(
|
||||
f"val_loss{li}_epoch", val_loss_sum / len(val_loader)
|
||||
)
|
||||
|
||||
# convergence of multiple losses may be ambiguous
|
||||
self._val_loss = sum(val_loss_sums) / len(val_loader)
|
||||
|
||||
return val_loss_sums
|
||||
|
||||
def train(
|
||||
self,
|
||||
dataset: BatchedDataset[..., ..., I],
|
||||
@@ -122,19 +268,21 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
"""
|
||||
Note: this method attempts to implement a general scheme for passing
|
||||
needed items to the estimator's loss function from the dataloader. The
|
||||
abstract `Estimator` base only requires the model output be provided
|
||||
abstract ``Estimator`` base only requires the model output be provided
|
||||
for any given loss calculation, but concrete estimators will often
|
||||
require additional arguments (e.g., labels or length masks, as
|
||||
is the case with sequential models). In any case, this method defers
|
||||
any further logic to the `loss` method of the underlying estimator, so
|
||||
require additional arguments (e.g., labels or length masks, as is the
|
||||
case with sequential models). In any case, this method defers any
|
||||
further logic to the ``loss`` method of the underlying estimator, so
|
||||
one should take care to synchronize the sample structure with ``dataset``
|
||||
to match that expected by `self.estimator.loss(...)`.
|
||||
to match that expected by ``self.estimator.loss(...)``.
|
||||
|
||||
On batch_estimator_map:
|
||||
.. admonition:: On batch_estimator_map
|
||||
|
||||
Dataloader collate functions are responsible for mapping a collection
|
||||
of items into an item of collections, roughly speaking. If items are
|
||||
tuples of tensors,
|
||||
Dataloader collate functions are responsible for mapping a
|
||||
collection of items into an item of collections, roughly speaking.
|
||||
If items are tuples of tensors,
|
||||
|
||||
.. code-block::
|
||||
|
||||
[
|
||||
( [1, 1], [1, 1] ),
|
||||
@@ -145,6 +293,8 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
the collate function maps back into the item skeleton, producing a
|
||||
single tuple of (stacked) tensors
|
||||
|
||||
.. code-block::
|
||||
|
||||
( [[1, 1],
|
||||
[2, 2],
|
||||
[3, 3]],
|
||||
@@ -153,9 +303,10 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
[2, 2],
|
||||
[3, 3]] )
|
||||
|
||||
This function should map from batches (which should be *item shaped*,
|
||||
i.e., have an `I` skeleton, even if stacked items may be different on
|
||||
the inside) into estimator keyword arguments (type `K`).
|
||||
This function should map from batches (which should be *item
|
||||
shaped*, i.e., have an ``I`` skeleton, even if stacked items may be
|
||||
different on the inside) into estimator keyword arguments (type
|
||||
``K``).
|
||||
|
||||
Parameters:
|
||||
lr: learning rate (default: 1e-3)
|
||||
@@ -212,131 +363,42 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
while self._epoch <= max_epochs and not self._converged(
|
||||
self._epoch, stop_after_epochs
|
||||
):
|
||||
print(f"Training epoch {self._epoch}/{max_epochs}...")
|
||||
print(f"Stagnant epochs {self._stagnant_epochs}/{stop_after_epochs}...")
|
||||
train_frac = f"{self._epoch}/{max_epochs}"
|
||||
stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}"
|
||||
print(f"Training epoch {train_frac}...")
|
||||
print(f"Stagnant epochs {stag_frac}...")
|
||||
|
||||
epoch_start_time = time.time()
|
||||
train_loss_sums = []
|
||||
self.estimator.train()
|
||||
with tqdm(train_loader, unit="batch") as train_epoch:
|
||||
for i, batch_data in enumerate(train_epoch):
|
||||
est_kwargs = batch_estimator_map(batch_data, self)
|
||||
inputs = est_kwargs["inputs"]
|
||||
|
||||
# one-time logging
|
||||
if self._step == 0:
|
||||
writer.add_graph(ModelWrapper(self.estimator), est_kwargs)
|
||||
|
||||
# once-per-epoch logging
|
||||
if i == 0:
|
||||
self.estimator.epoch_write(
|
||||
self._train_epoch(
|
||||
train_loader,
|
||||
batch_estimator_map,
|
||||
optimizers,
|
||||
writer,
|
||||
step=self._step,
|
||||
val=False,
|
||||
**est_kwargs
|
||||
)
|
||||
|
||||
train_losses = self.estimator.loss(**est_kwargs)
|
||||
train_loss_items = []
|
||||
for o_idx, optimizer in enumerate(optimizers):
|
||||
optimizer.zero_grad()
|
||||
train_loss = next(train_losses)
|
||||
|
||||
if len(train_loss_sums) <= o_idx:
|
||||
train_loss_sums.append(0.0)
|
||||
|
||||
train_loss_item = train_loss.item()
|
||||
train_loss_sums[o_idx] += train_loss_item
|
||||
train_loss_items.append(train_loss_item)
|
||||
|
||||
train_loss.backward()
|
||||
|
||||
# clip gradients for optimizer's parameters
|
||||
if max_grad_norm is not None:
|
||||
opt_params = self._get_optimizer_parameters(optimizer)
|
||||
clip_grad_norm_(opt_params, max_norm=max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
|
||||
self._step += len(inputs)
|
||||
|
||||
for train_loss_item, train_loss_sum in zip(
|
||||
train_loss_items,
|
||||
train_loss_sums,
|
||||
strict=True,
|
||||
):
|
||||
train_epoch.set_postfix(loss=f"{train_loss_sum/(i+1):8.2f}")
|
||||
self._add_summary_item("train_loss", train_loss_item)
|
||||
|
||||
estimator_metrics = self.estimator.metrics(**est_kwargs)
|
||||
for metric_name, metric_value in estimator_metrics.items():
|
||||
self._add_summary_item(f"train_{metric_name}", metric_value)
|
||||
|
||||
self.estimator.epoch_step()
|
||||
|
||||
for li, train_loss_sum in enumerate(train_loss_sums):
|
||||
self._add_summary_item(
|
||||
f"train_loss{li}_epoch", train_loss_sum / len(train_loader)
|
||||
max_grad_norm
|
||||
)
|
||||
|
||||
if val_frac > 0:
|
||||
val_loss_sums = []
|
||||
self.estimator.eval()
|
||||
with tqdm(val_loader, unit="batch") as val_epoch:
|
||||
for i, batch_data in enumerate(val_epoch):
|
||||
est_kwargs = batch_estimator_map(batch_data, self)
|
||||
inputs = est_kwargs["inputs"]
|
||||
|
||||
# once-per-epoch logging
|
||||
if i == 0:
|
||||
self.estimator.epoch_write(
|
||||
self._val_epoch(
|
||||
val_loader,
|
||||
batch_estimator_map,
|
||||
optimizers,
|
||||
writer,
|
||||
step=self._step,
|
||||
val=True,
|
||||
**est_kwargs
|
||||
)
|
||||
|
||||
val_losses = self.estimator.loss(**est_kwargs)
|
||||
val_loss_items = []
|
||||
for o_idx in range(len(optimizers)):
|
||||
val_loss = next(val_losses)
|
||||
|
||||
if len(val_loss_sums) <= o_idx:
|
||||
val_loss_sums.append(0.0)
|
||||
|
||||
val_loss_item = val_loss.item()
|
||||
val_loss_sums[o_idx] += val_loss_item
|
||||
val_loss_items.append(val_loss_item)
|
||||
|
||||
for val_loss_item, val_loss_sum in zip(
|
||||
val_loss_items,
|
||||
val_loss_sums,
|
||||
strict=True,
|
||||
):
|
||||
val_epoch.set_postfix(loss=f"{val_loss_sum/(i+1):8.2f}")
|
||||
self._add_summary_item("val_loss", val_loss_item)
|
||||
|
||||
estimator_metrics = self.estimator.metrics(**est_kwargs)
|
||||
for metric_name, metric_value in estimator_metrics.items():
|
||||
self._add_summary_item(f"val_{metric_name}", metric_value)
|
||||
|
||||
for li, val_loss_sum in enumerate(val_loss_sums):
|
||||
self._add_summary_item(
|
||||
f"val_loss{li}_epoch", val_loss_sum / len(val_loader)
|
||||
"epoch_time_sec",
|
||||
time.time() - epoch_start_time
|
||||
)
|
||||
|
||||
# convergence of multiple losses may be ambiguous
|
||||
self._val_loss = sum(val_loss_sums) / len(val_loader)
|
||||
|
||||
self._add_summary_item("epoch_time_sec", time.time() - epoch_start_time)
|
||||
|
||||
if self._epoch % summarize_every == 0:
|
||||
self._summarize(writer, self._epoch)
|
||||
|
||||
# save checkpoint
|
||||
if self._epoch % chkpt_every == 0:
|
||||
self.save_model(
|
||||
self._epoch, self.chkpt_dir, dir_prefix
|
||||
self._epoch,
|
||||
self.chkpt_dir,
|
||||
dir_prefix
|
||||
)
|
||||
|
||||
self._epoch += 1
|
||||
|
||||
Reference in New Issue
Block a user