From 805262dfc44e1a5b1b53e3350521765af593e585 Mon Sep 17 00:00:00 2001 From: smgr Date: Thu, 5 Mar 2026 01:14:24 -0800 Subject: [PATCH] refactor Trainer train/val loop --- README.md | 16 +- doc/Makefile | 20 ++ doc/_templates/autosummary.md | 9 + doc/_templates/autosummary/module.rst | 8 + doc/conf.py | 49 +++++ doc/index.md | 29 +++ doc/make.bat | 35 ++++ doc/reference/documentation/index.md | 5 + doc/reference/documentation/sphinx.md | 111 ++++++++++ example/dataset.py | 20 ++ trainlib/estimator.py | 2 +- trainlib/estimators/tdnn.py | 3 - trainlib/trainer.py | 283 ++++++++++++++++---------- 13 files changed, 462 insertions(+), 128 deletions(-) create mode 100644 doc/Makefile create mode 100644 doc/_templates/autosummary.md create mode 100644 doc/_templates/autosummary/module.rst create mode 100644 doc/conf.py create mode 100644 doc/index.md create mode 100644 doc/make.bat create mode 100644 doc/reference/documentation/index.md create mode 100644 doc/reference/documentation/sphinx.md create mode 100644 example/dataset.py diff --git a/README.md b/README.md index 551f90e..64bb6e2 100644 --- a/README.md +++ b/README.md @@ -2,16 +2,10 @@ Package summary goes here, ideally with a diagram # Install -Installation instructions +The `trainlib` package can be installed from PyPI: ```sh -pip install -``` - -or as a CLI tool - -```sh -uv tool install +pip install trainlib ``` # Development @@ -20,16 +14,16 @@ uv tool install - Depending on needs, install the development dependencies with `uv sync --extra dev`. -# Testing +## Testing - To run the unit tests, make sure to first have the test dependencies installed with `uv sync --extra test`, then run `make test`. - For notebook testing, run `make install-kernel` to make the environment available as a Jupyter kernel (to be selected when running notebooks). -# Documentation +## Documentation - Install the documentation dependencies with `uv sync --extra doc`. 
- Run `make docs-build` (optionally preceded by `make docs-clean`), and serve - locally with `docs-serve`. + locally with `make docs-serve`. # Development remarks - Across `Trainer` / `Estimator` / `Dataset`, I've considered a diff --git a/doc/Makefile b/doc/Makefile new file mode 100644 index 0000000..d4bb2cb --- /dev/null +++ b/doc/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/doc/_templates/autosummary.md b/doc/_templates/autosummary.md new file mode 100644 index 0000000..216f1df --- /dev/null +++ b/doc/_templates/autosummary.md @@ -0,0 +1,9 @@ +# {{ fullname | escape }} + +```{automodule} +{{ fullname }} +:members: +:undoc-members: +:show-inheritance: +:imported-members: +``` diff --git a/doc/_templates/autosummary/module.rst b/doc/_templates/autosummary/module.rst new file mode 100644 index 0000000..6d5a51d --- /dev/null +++ b/doc/_templates/autosummary/module.rst @@ -0,0 +1,8 @@ +{{ fullname | escape | underline}} + +.. automodule:: {{ fullname }} + :members: + :undoc-members: + :show-inheritance: + :imported-members: + diff --git a/doc/conf.py b/doc/conf.py new file mode 100644 index 0000000..9753af1 --- /dev/null +++ b/doc/conf.py @@ -0,0 +1,49 @@ +# Configuration file for the Sphinx documentation builder. 
+# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Project information ------------------------------------------------------ +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +project = "" +copyright = "2025, Sam Griesemer" +author = "Sam Griesemer" + +# -- General configuration ---------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + +extensions = [ + "sphinx.ext.autodoc", + # enables a directive to be specified manually that gathers module/object + # summary details in a table + "sphinx.ext.autosummary", + # allow viewing source in the HTML pages + "sphinx.ext.viewcode", + # only really applies to manual docs; docstrings still need RST-like + "myst_parser", + # enables Google-style docstring formats + "sphinx.ext.napoleon", + # external extension that allows arg types to be inferred by type hints + "sphinx_autodoc_typehints", +] +autosummary_generate = True +autosummary_imported_members = True + +# include __init__ definitions in autodoc +autodoc_default_options = { + "special-members": "__init__", +} + +templates_path = ["_templates"] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] + + +# -- Options for HTML output -------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_theme = "furo" +html_static_path = ["_static"] +# html_sidebars = { +# '**': ['/modules.html'], +# } diff --git a/doc/index.md b/doc/index.md new file mode 100644 index 0000000..f8fc36c --- /dev/null +++ b/doc/index.md @@ -0,0 +1,29 @@ +# `` package docs +{ref}`genindex` +{ref}`modindex` +{ref}`search` + +```{eval-rst} +.. 
autosummary:: + :nosignatures: + + # list modules here for quick links +``` + +```{toctree} +:maxdepth: 3 +:caption: Autoref + +_autoref/.rst +``` + +```{toctree} +:maxdepth: 3 +:caption: Contents + +reference/documentation/index +reference/site/index +``` + +```{include} ../README.md +``` diff --git a/doc/make.bat b/doc/make.bat new file mode 100644 index 0000000..32bb245 --- /dev/null +++ b/doc/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/doc/reference/documentation/index.md b/doc/reference/documentation/index.md new file mode 100644 index 0000000..58b65a4 --- /dev/null +++ b/doc/reference/documentation/index.md @@ -0,0 +1,5 @@ +# Documentation + +```{toctree} +sphinx +``` diff --git a/doc/reference/documentation/sphinx.md b/doc/reference/documentation/sphinx.md new file mode 100644 index 0000000..33d6f27 --- /dev/null +++ b/doc/reference/documentation/sphinx.md @@ -0,0 +1,111 @@ +# Sphinx +The primary driver of this package's documentation is Sphinx's `autodoc` extension, +using the [Furo theme][1]. + +**High-level details**: + +- `sphinx-apidoc` generates package-based documentation to the `_autoref/` directory, + with navigation available under "Autoref" in the sidebar. 
+- Markdown-based documentation files are manually written under the `reference/` + directory, showing up under "Contents" in the sidebar. + +## Detailed directory structure +All files are placed under `docs/sphinx`: + +- `_`-prefixed are Sphinx-managed directories + * `_build/html/` houses output HTML files + * `_autoref/` is the target for module-based RST files written by `autodoc` +- `reference/`: houses all manually written documentation (totally separate from + auto-generated package docs) +- `conf.py`: single Sphinx configuration file +- `index.md`: documentation index, sets up a persistent sidebar across all other pages + +For manually written documentation under `reference/`, topics are nested as needed. Within +a nested directory `reference/`, an `index.md` should be created with content like: + +``` +# + +\`\`\`{toctree} +:hidden: + +sub-topic-1.rst +sub-topic-2.rst +... +\`\`\` +``` + +This will add the nested directory to the sidebar navigation, using the name set under the +top-level header. See [Markdown syntax](#markdown-syntax) for more details on the syntax. + +## Sphinx autodoc +Sphinx's `autodoc` extension allows automatic generation of documents according to +(Python) subpackage structure and available docstrings. A few notes here: + +- In the `conf.py` file, autodoc is enabled by adding `"sphinx.ext.autodoc"` to + the extensions list. `"sphinx.ext.viewcode"` can also be added to provide + links to source code. +- Documents are actually generated by calling the `sphinx-apidoc` CLI command. The + current Makefile uses the following call: + + ```sh + sphinx-apidoc --module-first -o docs/sphinx/_autoref/ localsys + ``` + + This writes the automatically generated docs for modules in the package at the + local directory `localsys/` to the `docs/sphinx/_autoref` directory. These are + reStructuredText files by default. + * `--module-first` places the module-level descriptions at the top of the module page.
+ By default, this is placed at the bottom (oddly), and can be obscured by large lists + of subpackages if this flag isn't provided. + * See available `sphinx-apidoc` options [here][2], as well as more advanced config + [here][3]. + + +## Markdown syntax +The `myst_parser` extension enables Markdown (or something close to it) to be used when +writing documentation files. The Sphinx directives can be difficult to track, and +they change slightly under the MyST Markdown syntax. The following are a few common +blocks: + +**Page hierarchies**: the following will generate a link hierarchy according to the provided +pages: + +``` +\`\`\`{toctree} +:maxdepth: +:caption: +:hidden: + +example-file-1 +example-file-2 +example-dir/index +... +\`\`\` +``` + +- `:maxdepth:` limits the depth of nesting +- `:caption:` title for the group of pages +- `:hidden:` if provided, links will only show in the sidebar (hidden on the page) +- Constituent files: listed files will be rendered as a link directly. If a listed file + has a `{toctree}` directive, this tree will be rendered in place of the page's link as a + dropdown. The dropdown will be named according to the file's top-level heading, and + clicking directly on the dropdown header will show that page's content. Files found in + the tree will be placed as links under the dropdown, recursively subject to the same rules + described here.
+ +**Include files**: the following will include file content +pages: + +``` +\`\`\`{include} README.md +\`\`\` +``` + +**Reference directives** + + +[1]: https://pradyunsg.me/furo/ +[2]: https://www.sphinx-doc.org/en/master/man/sphinx-apidoc.html +[3]: https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html# + diff --git a/example/dataset.py b/example/dataset.py new file mode 100644 index 0000000..28d87b2 --- /dev/null +++ b/example/dataset.py @@ -0,0 +1,20 @@ +from typing import NamedTuple + +from trainlib.domain import SequenceDomain +from trainlib.datasets.memory import TupleDataset + + +class Record(NamedTuple): + a: int + b: str + +tl_domain = SequenceDomain[Record]([ + Record(1, "1"), + Record(2, "2"), +]) + +class R0(TupleDataset[Record]): + item_tuple = Record + + def _process_item_data(self, item_data, item_index): + return (item_data[0],) diff --git a/trainlib/estimator.py b/trainlib/estimator.py index cb1257c..b763806 100644 --- a/trainlib/estimator.py +++ b/trainlib/estimator.py @@ -24,7 +24,7 @@ from torch import nn, Tensor from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter -from trainlib.util.type import OptimizerKwargs +from trainlib.utils.type import OptimizerKwargs logger: logging.Logger = logging.getLogger(__name__) diff --git a/trainlib/estimators/tdnn.py b/trainlib/estimators/tdnn.py index 6d9a312..04146ca 100644 --- a/trainlib/estimators/tdnn.py +++ b/trainlib/estimators/tdnn.py @@ -1,10 +1,7 @@ import logging -from collections.abc import Generator -import torch import torch.nn.functional as F from torch import nn, Tensor -from torch.optim import Optimizer from torch.nn.utils.parametrizations import weight_norm logger: logging.Logger = logging.getLogger(__name__) diff --git a/trainlib/trainer.py b/trainlib/trainer.py index c96f174..9f1d4dc 100644 --- a/trainlib/trainer.py +++ b/trainlib/trainer.py @@ -11,6 +11,7 @@ from collections.abc import Callable import torch from tqdm import tqdm from torch 
import cuda, Tensor +from torch.optim import Optimizer from torch.nn.utils import clip_grad_norm_ from torch.utils.data import Dataset, DataLoader from torch.utils.tensorboard import SummaryWriter @@ -98,6 +99,151 @@ class Trainer[I, K: EstimatorKwargs]: self._stagnant_epochs = 0 self._best_model_state_dict: dict[str, Any] = {} + def _train_epoch( + self, + train_loader: DataLoader, + batch_estimator_map: Callable[[I, Self], K], + optimizers: tuple[Optimizer, ...], + writer: SummaryWriter, + max_grad_norm: float | None = None, + ) -> list[float]: + """ + Train the estimator for a single epoch. + """ + + train_loss_sums = [] + self.estimator.train() + with tqdm(train_loader, unit="batch") as train_epoch: + for i, batch_data in enumerate(train_epoch): + est_kwargs = batch_estimator_map(batch_data, self) + inputs = est_kwargs["inputs"] + + # one-time logging + if self._step == 0: + writer.add_graph( + ModelWrapper(self.estimator), + est_kwargs + ) + + # once-per-epoch logging + if i == 0: + self.estimator.epoch_write( + writer, + step=self._step, + val=False, + **est_kwargs + ) + + train_losses = self.estimator.loss(**est_kwargs) + train_loss_items = [] + for o_idx, optimizer in enumerate(optimizers): + optimizer.zero_grad() + train_loss = next(train_losses) + + if len(train_loss_sums) <= o_idx: + train_loss_sums.append(0.0) + + train_loss_item = train_loss.item() + train_loss_sums[o_idx] += train_loss_item + train_loss_items.append(train_loss_item) + + train_loss.backward() + + # clip gradients for optimizer's parameters + if max_grad_norm is not None: + clip_grad_norm_( + self._get_optimizer_parameters(optimizer), + max_norm=max_grad_norm + ) + + optimizer.step() + + self._step += len(inputs) + + for train_loss_item, train_loss_sum in zip( + train_loss_items, + train_loss_sums, + strict=True, + ): + train_epoch.set_postfix(loss=f"{train_loss_sum/(i+1):8.2f}") + self._add_summary_item("train_loss", train_loss_item) + + estimator_metrics = 
self.estimator.metrics(**est_kwargs) + for metric_name, metric_value in estimator_metrics.items(): + self._add_summary_item( + f"train_{metric_name}", + metric_value + ) + + self.estimator.epoch_step() + + for li, train_loss_sum in enumerate(train_loss_sums): + self._add_summary_item( + f"train_loss{li}_epoch", train_loss_sum / len(train_loader) + ) + + return train_loss_sums + + def _val_epoch( + self, + val_loader: DataLoader, + batch_estimator_map: Callable[[I, Self], K], + optimizers: tuple[Optimizer, ...], + writer: SummaryWriter, + ) -> list[float]: + """ + Perform and record validation scores for a single epoch. + """ + + val_loss_sums = [] + self.estimator.eval() + with tqdm(val_loader, unit="batch") as val_epoch: + for i, batch_data in enumerate(val_epoch): + est_kwargs = batch_estimator_map(batch_data, self) + + # once-per-epoch logging + if i == 0: + self.estimator.epoch_write( + writer, + step=self._step, + val=True, + **est_kwargs + ) + + val_losses = self.estimator.loss(**est_kwargs) + val_loss_items = [] + for o_idx in range(len(optimizers)): + val_loss = next(val_losses) + + if len(val_loss_sums) <= o_idx: + val_loss_sums.append(0.0) + + val_loss_item = val_loss.item() + val_loss_sums[o_idx] += val_loss_item + val_loss_items.append(val_loss_item) + + for val_loss_item, val_loss_sum in zip( + val_loss_items, + val_loss_sums, + strict=True, + ): + val_epoch.set_postfix(loss=f"{val_loss_sum/(i+1):8.2f}") + self._add_summary_item("val_loss", val_loss_item) + + estimator_metrics = self.estimator.metrics(**est_kwargs) + for metric_name, metric_value in estimator_metrics.items(): + self._add_summary_item(f"val_{metric_name}", metric_value) + + for li, val_loss_sum in enumerate(val_loss_sums): + self._add_summary_item( + f"val_loss{li}_epoch", val_loss_sum / len(val_loader) + ) + + # convergence of multiple losses may be ambiguous + self._val_loss = sum(val_loss_sums) / len(val_loader) + + return val_loss_sums + def train( self, dataset: BatchedDataset[..., 
..., I], @@ -212,123 +358,32 @@ class Trainer[I, K: EstimatorKwargs]: while self._epoch <= max_epochs and not self._converged( self._epoch, stop_after_epochs ): - print(f"Training epoch {self._epoch}/{max_epochs}...") - print(f"Stagnant epochs {self._stagnant_epochs}/{stop_after_epochs}...") + train_frac = f"{self._epoch}/{max_epochs}" + stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}" + print(f"Training epoch {train_frac}...") + print(f"Stagnant epochs {stag_frac}...") epoch_start_time = time.time() - train_loss_sums = [] - self.estimator.train() - with tqdm(train_loader, unit="batch") as train_epoch: - for i, batch_data in enumerate(train_epoch): - est_kwargs = batch_estimator_map(batch_data, self) - inputs = est_kwargs["inputs"] - - # one-time logging - if self._step == 0: - writer.add_graph(ModelWrapper(self.estimator), est_kwargs) - - # once-per-epoch logging - if i == 0: - self.estimator.epoch_write( - writer, - step=self._step, - val=False, - **est_kwargs - ) - - train_losses = self.estimator.loss(**est_kwargs) - train_loss_items = [] - for o_idx, optimizer in enumerate(optimizers): - optimizer.zero_grad() - train_loss = next(train_losses) - - if len(train_loss_sums) <= o_idx: - train_loss_sums.append(0.0) - - train_loss_item = train_loss.item() - train_loss_sums[o_idx] += train_loss_item - train_loss_items.append(train_loss_item) - - train_loss.backward() - - # clip gradients for optimizer's parameters - if max_grad_norm is not None: - opt_params = self._get_optimizer_parameters(optimizer) - clip_grad_norm_(opt_params, max_norm=max_grad_norm) - - optimizer.step() - - self._step += len(inputs) - - for train_loss_item, train_loss_sum in zip( - train_loss_items, - train_loss_sums, - strict=True, - ): - train_epoch.set_postfix(loss=f"{train_loss_sum/(i+1):8.2f}") - self._add_summary_item("train_loss", train_loss_item) - - estimator_metrics = self.estimator.metrics(**est_kwargs) - for metric_name, metric_value in estimator_metrics.items(): - 
self._add_summary_item(f"train_{metric_name}", metric_value) - - self.estimator.epoch_step() - - for li, train_loss_sum in enumerate(train_loss_sums): - self._add_summary_item( - f"train_loss{li}_epoch", train_loss_sum / len(train_loader) - ) + self._train_epoch( + train_loader, + batch_estimator_map, + optimizers, + writer, + max_grad_norm + ) if val_frac > 0: - val_loss_sums = [] - self.estimator.eval() - with tqdm(val_loader, unit="batch") as val_epoch: - for i, batch_data in enumerate(val_epoch): - est_kwargs = batch_estimator_map(batch_data, self) - inputs = est_kwargs["inputs"] + self._val_epoch( + val_loader, + batch_estimator_map, + optimizers, + writer, + ) - # once-per-epoch logging - if i == 0: - self.estimator.epoch_write( - writer, - step=self._step, - val=True, - **est_kwargs - ) - - val_losses = self.estimator.loss(**est_kwargs) - val_loss_items = [] - for o_idx in range(len(optimizers)): - val_loss = next(val_losses) - - if len(val_loss_sums) <= o_idx: - val_loss_sums.append(0.0) - - val_loss_item = val_loss.item() - val_loss_sums[o_idx] += val_loss_item - val_loss_items.append(val_loss_item) - - for val_loss_item, val_loss_sum in zip( - val_loss_items, - val_loss_sums, - strict=True, - ): - val_epoch.set_postfix(loss=f"{val_loss_sum/(i+1):8.2f}") - self._add_summary_item("val_loss", val_loss_item) - - estimator_metrics = self.estimator.metrics(**est_kwargs) - for metric_name, metric_value in estimator_metrics.items(): - self._add_summary_item(f"val_{metric_name}", metric_value) - - for li, val_loss_sum in enumerate(val_loss_sums): - self._add_summary_item( - f"val_loss{li}_epoch", val_loss_sum / len(val_loader) - ) - - # convergence of multiple losses may be ambiguous - self._val_loss = sum(val_loss_sums) / len(val_loader) - - self._add_summary_item("epoch_time_sec", time.time() - epoch_start_time) + self._add_summary_item( + "epoch_time_sec", + time.time() - epoch_start_time + ) if self._epoch % summarize_every == 0: self._summarize(writer, 
self._epoch) @@ -336,7 +391,9 @@ class Trainer[I, K: EstimatorKwargs]: # save checkpoint if self._epoch % chkpt_every == 0: self.save_model( - self._epoch, self.chkpt_dir, dir_prefix + self._epoch, + self.chkpt_dir, + dir_prefix ) self._epoch += 1