Compare commits
2 Commits
c473e48b5b
...
faeef9c72a
| Author | SHA1 | Date | |
|---|---|---|---|
| faeef9c72a | |||
| 805262dfc4 |
16
README.md
16
README.md
@@ -2,16 +2,10 @@
|
|||||||
Package summary goes here, ideally with a diagram
|
Package summary goes here, ideally with a diagram
|
||||||
|
|
||||||
# Install
|
# Install
|
||||||
Installation instructions
|
The `trainlib` package can be installed from PyPI:
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
pip install <package>
|
pip install trainlib
|
||||||
```
|
|
||||||
|
|
||||||
or as a CLI tool
|
|
||||||
|
|
||||||
```sh
|
|
||||||
uv tool install <package>
|
|
||||||
```
|
```
|
||||||
|
|
||||||
# Development
|
# Development
|
||||||
@@ -20,16 +14,16 @@ uv tool install <package>
|
|||||||
- Depending on needs, install the development dependencies with `uv sync
|
- Depending on needs, install the development dependencies with `uv sync
|
||||||
--extra dev`.
|
--extra dev`.
|
||||||
|
|
||||||
# Testing
|
## Testing
|
||||||
- To run the unit tests, make sure to first have the test dependencies
|
- To run the unit tests, make sure to first have the test dependencies
|
||||||
installed with `uv sync --extra test`, then run `make test`.
|
installed with `uv sync --extra test`, then run `make test`.
|
||||||
- For notebook testing, run `make install-kernel` to make the environment
|
- For notebook testing, run `make install-kernel` to make the environment
|
||||||
available as a Jupyter kernel (to be selected when running notebooks).
|
available as a Jupyter kernel (to be selected when running notebooks).
|
||||||
|
|
||||||
# Documentation
|
## Documentation
|
||||||
- Install the documentation dependencies with `uv sync --extra doc`.
|
- Install the documentation dependencies with `uv sync --extra doc`.
|
||||||
- Run `make docs-build` (optionally preceded by `make docs-clean`), and serve
|
- Run `make docs-build` (optionally preceded by `make docs-clean`), and serve
|
||||||
locally with `docs-serve`.
|
locally with `make docs-serve`.
|
||||||
|
|
||||||
# Development remarks
|
# Development remarks
|
||||||
- Across `Trainer` / `Estimator` / `Dataset`, I've considered a
|
- Across `Trainer` / `Estimator` / `Dataset`, I've considered a
|
||||||
|
|||||||
20
doc/Makefile
Normal file
20
doc/Makefile
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
# Minimal makefile for Sphinx documentation
|
||||||
|
#
|
||||||
|
|
||||||
|
# You can set these variables from the command line, and also
|
||||||
|
# from the environment for the first two.
|
||||||
|
SPHINXOPTS ?=
|
||||||
|
SPHINXBUILD ?= sphinx-build
|
||||||
|
SOURCEDIR = .
|
||||||
|
BUILDDIR = _build
|
||||||
|
|
||||||
|
# Put it first so that "make" without argument is like "make help".
|
||||||
|
help:
|
||||||
|
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||||
|
|
||||||
|
.PHONY: help Makefile
|
||||||
|
|
||||||
|
# Catch-all target: route all unknown targets to Sphinx using the new
|
||||||
|
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
||||||
|
%: Makefile
|
||||||
|
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||||
9
doc/_templates/autosummary.md
vendored
Normal file
9
doc/_templates/autosummary.md
vendored
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
# {{ fullname | escape }}
|
||||||
|
|
||||||
|
```{automodule}
|
||||||
|
{{ fullname }}
|
||||||
|
:members:
|
||||||
|
:undoc-members:
|
||||||
|
:show-inheritance:
|
||||||
|
:imported-members:
|
||||||
|
```
|
||||||
8
doc/_templates/autosummary/module.rst
vendored
Normal file
8
doc/_templates/autosummary/module.rst
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
{{ fullname | escape | underline}}
|
||||||
|
|
||||||
|
.. automodule:: {{ fullname }}
|
||||||
|
:members:
|
||||||
|
:undoc-members:
|
||||||
|
:show-inheritance:
|
||||||
|
:imported-members:
|
||||||
|
|
||||||
49
doc/conf.py
Normal file
49
doc/conf.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
# Configuration file for the Sphinx documentation builder.
|
||||||
|
#
|
||||||
|
# For the full list of built-in configuration values, see the documentation:
|
||||||
|
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
||||||
|
|
||||||
|
# -- Project information ------------------------------------------------------
|
||||||
|
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
|
||||||
|
|
||||||
|
project = "<package-name>"
|
||||||
|
copyright = "2025, Sam Griesemer"
|
||||||
|
author = "Sam Griesemer"
|
||||||
|
|
||||||
|
# -- General configuration ----------------------------------------------------
|
||||||
|
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
||||||
|
|
||||||
|
extensions = [
|
||||||
|
"sphinx.ext.autodoc",
|
||||||
|
# enables a directive to be specified manually that gathers module/object
|
||||||
|
# summary details in a table
|
||||||
|
"sphinx.ext.autosummary",
|
||||||
|
# allow viewing source in the HTML pages
|
||||||
|
"sphinx.ext.viewcode",
|
||||||
|
# only really applies to manual docs; docstrings still need RST-like
|
||||||
|
"myst_parser",
|
||||||
|
# enables Google-style docstring formats
|
||||||
|
"sphinx.ext.napoleon",
|
||||||
|
# external extension that allows arg types to be inferred by type hints
|
||||||
|
"sphinx_autodoc_typehints",
|
||||||
|
]
|
||||||
|
autosummary_generate = True
|
||||||
|
autosummary_imported_members = True
|
||||||
|
|
||||||
|
# include __init__ definitions in autodoc
|
||||||
|
autodoc_default_options = {
|
||||||
|
"special-members": "__init__",
|
||||||
|
}
|
||||||
|
|
||||||
|
templates_path = ["_templates"]
|
||||||
|
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
||||||
|
|
||||||
|
|
||||||
|
# -- Options for HTML output --------------------------------------------------
|
||||||
|
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
|
||||||
|
|
||||||
|
html_theme = "furo"
|
||||||
|
html_static_path = ["_static"]
|
||||||
|
# html_sidebars = {
|
||||||
|
# '**': ['/modules.html'],
|
||||||
|
# }
|
||||||
29
doc/index.md
Normal file
29
doc/index.md
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
# `<project-name>` package docs
|
||||||
|
{ref}`genindex`
|
||||||
|
{ref}`modindex`
|
||||||
|
{ref}`search`
|
||||||
|
|
||||||
|
```{eval-rst}
|
||||||
|
.. autosummary::
|
||||||
|
:nosignatures:
|
||||||
|
|
||||||
|
# list modules here for quick links
|
||||||
|
```
|
||||||
|
|
||||||
|
```{toctree}
|
||||||
|
:maxdepth: 3
|
||||||
|
:caption: Autoref
|
||||||
|
|
||||||
|
_autoref/<project-name>.rst
|
||||||
|
```
|
||||||
|
|
||||||
|
```{toctree}
|
||||||
|
:maxdepth: 3
|
||||||
|
:caption: Contents
|
||||||
|
|
||||||
|
reference/documentation/index
|
||||||
|
reference/site/index
|
||||||
|
```
|
||||||
|
|
||||||
|
```{include} ../README.md
|
||||||
|
```
|
||||||
35
doc/make.bat
Normal file
35
doc/make.bat
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
@ECHO OFF
|
||||||
|
|
||||||
|
pushd %~dp0
|
||||||
|
|
||||||
|
REM Command file for Sphinx documentation
|
||||||
|
|
||||||
|
if "%SPHINXBUILD%" == "" (
|
||||||
|
set SPHINXBUILD=sphinx-build
|
||||||
|
)
|
||||||
|
set SOURCEDIR=.
|
||||||
|
set BUILDDIR=_build
|
||||||
|
|
||||||
|
%SPHINXBUILD% >NUL 2>NUL
|
||||||
|
if errorlevel 9009 (
|
||||||
|
echo.
|
||||||
|
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
||||||
|
echo.installed, then set the SPHINXBUILD environment variable to point
|
||||||
|
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
||||||
|
echo.may add the Sphinx directory to PATH.
|
||||||
|
echo.
|
||||||
|
echo.If you don't have Sphinx installed, grab it from
|
||||||
|
echo.https://www.sphinx-doc.org/
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
|
||||||
|
if "%1" == "" goto help
|
||||||
|
|
||||||
|
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||||
|
goto end
|
||||||
|
|
||||||
|
:help
|
||||||
|
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||||
|
|
||||||
|
:end
|
||||||
|
popd
|
||||||
5
doc/reference/documentation/index.md
Normal file
5
doc/reference/documentation/index.md
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# Documentation
|
||||||
|
|
||||||
|
```{toctree}
|
||||||
|
sphinx
|
||||||
|
```
|
||||||
111
doc/reference/documentation/sphinx.md
Normal file
111
doc/reference/documentation/sphinx.md
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
# Sphinx
|
||||||
|
The primary driver of this package's documentation is Sphinx's `autodoc` extension,
|
||||||
|
using the [Furo theme][1].
|
||||||
|
|
||||||
|
**High-level details**:
|
||||||
|
|
||||||
|
- `sphinx-apidoc` generates package-based documentation to the `_autoref/` directory,
|
||||||
|
with navigation available under "Autoref" in the sidebar.
|
||||||
|
- Markdown-based documentation files are manually written under the `reference/`
|
||||||
|
directory, showing up under "Contents" in the sidebar.
|
||||||
|
|
||||||
|
## Detailed directory structure
|
||||||
|
All files are placed under `docs/sphinx`:
|
||||||
|
|
||||||
|
- `_`-prefixed are Sphinx-managed directories
|
||||||
|
* `_build/html/` houses output HTML files
|
||||||
|
* `_autoref/` is the target for module-based RST files written by `autodoc`
|
||||||
|
- `reference/`: houses all manually written documentation (totally separate from
|
||||||
|
auto-generated package docs)
|
||||||
|
- `conf.py`: single Sphinx configuration file
|
||||||
|
- `index.md`: documentation index, setups up a persistent sidebar across all other pages
|
||||||
|
|
||||||
|
For manually written documentation under `reference/`, topics are nested as needed. Within
|
||||||
|
a nested directory `reference/<topic>`, an `index.md` should created with content like:
|
||||||
|
|
||||||
|
```
|
||||||
|
# <Topic>
|
||||||
|
|
||||||
|
\`\`\`{toctree}
|
||||||
|
:hidden:
|
||||||
|
|
||||||
|
sub-topic-1.rst
|
||||||
|
sub-topic-2.rst
|
||||||
|
...
|
||||||
|
\`\`\`
|
||||||
|
```
|
||||||
|
|
||||||
|
This will add the nested directory to the sidebar navigation, using the name set under the
|
||||||
|
top-level header. See [Markdown syntax][#markdown-syntax] for more details on the syntax.
|
||||||
|
|
||||||
|
## Sphinx autodoc
|
||||||
|
Sphinx's `autodoc` extension allows automatic generation of documents according to
|
||||||
|
(Python) subpackage structure and available docstrings. A few notes here:
|
||||||
|
|
||||||
|
- In the `conf.py` file, autodoc is enabled by adding `"sphinx.ext.autodoc"` to
|
||||||
|
the extensions list. `"sphinx.ext.viewcode"` can also be added to provide
|
||||||
|
links to source code.
|
||||||
|
- Documents are actually generated by calling the `sphinx-apidoc` CLI command. The
|
||||||
|
current Makefile uses the following call:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
sphinx-apidoc --module-first -o docs/sphinx/_autoref/ localsys
|
||||||
|
```
|
||||||
|
|
||||||
|
This writes the automatically generated docs for modules in the package at the
|
||||||
|
local directory `localsys/` to the `docs/sphinx/_autoref` directory. These are
|
||||||
|
reStructuredText files by default.
|
||||||
|
* `--module-first` places the module-level descriptions at the top of the module page.
|
||||||
|
By default, this is placed at the bottom (oddly), and can be obscured by large lists
|
||||||
|
of subpackages if this flag isn't provided.
|
||||||
|
* See available `sphinx-apidoc` options [here][2], as well as more advanced config
|
||||||
|
[here][3].
|
||||||
|
|
||||||
|
|
||||||
|
## Markdown syntax
|
||||||
|
The `myst_parser` extension enables Markdown (or something close to it) to be used when
|
||||||
|
writing documentation files. The Sphinx directives can be difficult to track, and
|
||||||
|
they change slightly under the MyST Markdown syntax. The following are a few common
|
||||||
|
blocks:
|
||||||
|
|
||||||
|
**Page hierarchies**: the following will generate link hierarchy according to the provided
|
||||||
|
pages:
|
||||||
|
|
||||||
|
```
|
||||||
|
\`\`\`{toctree}
|
||||||
|
:maxdepth: <n>
|
||||||
|
:caption: <caption>
|
||||||
|
:hidden:
|
||||||
|
|
||||||
|
example-file-1
|
||||||
|
example-file-2
|
||||||
|
example-dir/index
|
||||||
|
...
|
||||||
|
\`\`\`
|
||||||
|
```
|
||||||
|
|
||||||
|
- `:maxdepth:` limits the depth of nesting
|
||||||
|
- `:caption:` title for the group of pages
|
||||||
|
- `:hidden:` if provided, links will only show in the sidebar (hidden on the page)
|
||||||
|
- Constituent files: listed files will be rendered as a link directly. If a listed file
|
||||||
|
has a `{toctree}` directive, this tree will be rendered in place of the page's link as a
|
||||||
|
dropdown. The dropdown will be named according to the file's top-level heading, and
|
||||||
|
clicking directly on the dropdown header will show that page's content. Files found in
|
||||||
|
the tree will be placed as links under the dropdown, recursively subject to same rules
|
||||||
|
described here.
|
||||||
|
|
||||||
|
**Include files**: the following will include file content
|
||||||
|
pages:
|
||||||
|
|
||||||
|
```
|
||||||
|
\`\`\`{include} README.md
|
||||||
|
\`\`\`
|
||||||
|
```
|
||||||
|
|
||||||
|
**Reference directives**
|
||||||
|
|
||||||
|
|
||||||
|
[1]: https://pradyunsg.me/furo/
|
||||||
|
[2]: https://www.sphinx-doc.org/en/master/man/sphinx-apidoc.html
|
||||||
|
[3]: https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#
|
||||||
|
|
||||||
20
example/dataset.py
Normal file
20
example/dataset.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
from trainlib.domain import SequenceDomain
|
||||||
|
from trainlib.datasets.memory import TupleDataset
|
||||||
|
|
||||||
|
|
||||||
|
class Record(NamedTuple):
|
||||||
|
a: int
|
||||||
|
b: str
|
||||||
|
|
||||||
|
tl_domain = SequenceDomain[Record]([
|
||||||
|
Record(1, "1"),
|
||||||
|
Record(2, "2"),
|
||||||
|
])
|
||||||
|
|
||||||
|
class R0(TupleDataset[Record]):
|
||||||
|
item_tuple = Record
|
||||||
|
|
||||||
|
def _process_item_data(self, item_data, item_index):
|
||||||
|
return (item_data[0],)
|
||||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "trainlib"
|
name = "trainlib"
|
||||||
version = "0.1.0"
|
version = "0.1.1"
|
||||||
description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training."
|
description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training."
|
||||||
requires-python = ">=3.13"
|
requires-python = ">=3.13"
|
||||||
authors = [
|
authors = [
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Marginalizing out the modality layer:
|
.. admonition:: Marginalizing out the modality layer
|
||||||
|
|
||||||
With ``domain`` being an instance variable, one possible interpretation of
|
With ``domain`` being an instance variable, one possible interpretation of
|
||||||
the object structures here is that one could completely abstract away
|
the object structures here is that one could completely abstract away
|
||||||
@@ -58,64 +58,71 @@ Marginalizing out the modality layer:
|
|||||||
particular case of ``_process_batch_data()``, it feels much better when
|
particular case of ``_process_batch_data()``, it feels much better when
|
||||||
it's on the inside.)
|
it's on the inside.)
|
||||||
|
|
||||||
Holding:
|
.. admonition:: Holding area
|
||||||
@abstractmethod
|
|
||||||
def _get_uri_groups(self) -> Iterable[tuple[U, ...]]:
|
|
||||||
Get URI groups for each batch.
|
|
||||||
|
|
||||||
If there's more than one URI per batch (e.g., a data file and a
|
.. code-block:: python
|
||||||
metadata file), zip the URIs such that we have a tuple of URIs per
|
|
||||||
batch.
|
|
||||||
|
|
||||||
Note that this effectively defines the index style over batches in the
|
@abstractmethod
|
||||||
attached domain. We get an ``int -> tuple[U, ...]`` map that turns
|
def _get_uri_groups(self) -> Iterable[tuple[U, ...]]:
|
||||||
batch indices into URIs that can be read under the domain.
|
Get URI groups for each batch.
|
||||||
``get_batch()`` turns an integer index into its corresponding
|
|
||||||
``tuple[U, ...]``, reading the resources with ``_read_resources()`` in
|
|
||||||
the tuple, treating them as providers of batched data.
|
|
||||||
``_read_resources()`` passes through to the attached domain logic,
|
|
||||||
which, although common, need not supply an explicit iterable of batch
|
|
||||||
items: we just access items with ``__getitem__()`` and may ask for
|
|
||||||
``__len__``. So the returned URI group collection (this method) does
|
|
||||||
need to be iterable to measure the number of batches, but the batch
|
|
||||||
objects that are ultimately produced by these URI groups need not be
|
|
||||||
iterables themselves.
|
|
||||||
|
|
||||||
raise NotImplementedError
|
If there's more than one URI per batch (e.g., a data file and a
|
||||||
|
metadata file), zip the URIs such that we have a tuple of URIs per
|
||||||
|
batch.
|
||||||
|
|
||||||
def _read_resources(
|
Note that this effectively defines the index style over batches in
|
||||||
self,
|
the attached domain. We get an ``int -> tuple[U, ...]`` map that
|
||||||
uri_group: tuple[U, ...],
|
turns batch indices into URIs that can be read under the domain.
|
||||||
batch_index: int
|
``get_batch()`` turns an integer index into its corresponding
|
||||||
) -> tuple[R, ...]:
|
``tuple[U, ...]``, reading the resources with ``_read_resources()``
|
||||||
Read batch files at the provided paths.
|
in the tuple, treating them as providers of batched data.
|
||||||
|
``_read_resources()`` passes through to the attached domain logic,
|
||||||
|
which, although common, need not supply an explicit iterable of
|
||||||
|
batch items: we just access items with ``__getitem__()`` and may
|
||||||
|
ask for ``__len__``. So the returned URI group collection (this
|
||||||
|
method) does need to be iterable to measure the number of batches,
|
||||||
|
but the batch objects that are ultimately produced by these URI
|
||||||
|
groups need not be iterables themselves.
|
||||||
|
|
||||||
This method should operate on a single tuple from the list of batch
|
raise NotImplementedError
|
||||||
tuples returned by the ``_get_uri_groups()`` method. That is, it reads
|
|
||||||
all of the resources for a single batch and returns a tuple of the same
|
|
||||||
size with their contents.
|
|
||||||
|
|
||||||
Note: the dependence on a batch index is mostly here to make
|
def _read_resources(
|
||||||
multi-dataset composition easier later. In-dataset, you don't need to
|
self,
|
||||||
know the batch index to to simply process URIs, but across datasets you
|
uri_group: tuple[U, ...],
|
||||||
need it to find out the origin of the batch (and process those URIs
|
batch_index: int
|
||||||
accordingly).
|
) -> tuple[R, ...]:
|
||||||
|
Read batch files at the provided paths.
|
||||||
|
|
||||||
return tuple(self.domain.read(uri) for uri in uri_group)
|
This method should operate on a single tuple from the list of batch
|
||||||
|
tuples returned by the ``_get_uri_groups()`` method. That is, it
|
||||||
|
reads all of the resources for a single batch and returns a tuple
|
||||||
|
of the same size with their contents.
|
||||||
|
|
||||||
# pulling the type variable out of the inline generic b/c `ty` has trouble
|
Note: the dependence on a batch index is mostly here to make
|
||||||
# understanding bound type variables in subclasses (specifically with Self@)
|
multi-dataset composition easier later. In-dataset, you don't need
|
||||||
T = TypeVar("T", bound=NamedTuple)
|
to know the batch index to to simply process URIs, but across
|
||||||
|
datasets you need it to find out the origin of the batch (and
|
||||||
|
process those URIs accordingly).
|
||||||
|
|
||||||
class NamedTupleDataset[I](Dataset):
|
return tuple(self.domain.read(uri) for uri in uri_group)
|
||||||
def __init__(self, data_list: list[I]) -> None:
|
|
||||||
self.data_list = data_list
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
.. code-block:: python
|
||||||
return len(self.data_list)
|
|
||||||
|
|
||||||
def __getitem__(self, index: int) -> I:
|
# pulling the type variable out of the inline generic b/c `ty` has
|
||||||
return self.data_list[index]
|
# trouble understanding bound type variables in subclasses
|
||||||
|
# (specifically with Self@)
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=NamedTuple)
|
||||||
|
|
||||||
|
class NamedTupleDataset[I](Dataset):
|
||||||
|
def __init__(self, data_list: list[I]) -> None:
|
||||||
|
self.data_list = data_list
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.data_list)
|
||||||
|
|
||||||
|
def __getitem__(self, index: int) -> I:
|
||||||
|
return self.data_list[index]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
@@ -156,38 +163,41 @@ class BatchedDataset[U, R, I](Dataset):
|
|||||||
which are used to concretize a domain ``Domain[U, R]``), and an item type
|
which are used to concretize a domain ``Domain[U, R]``), and an item type
|
||||||
``T`` (which has a ``tuple`` upper bound).
|
``T`` (which has a ``tuple`` upper bound).
|
||||||
|
|
||||||
Pipeline overview:
|
|
||||||
|
.. admonition:: Pipeline overview
|
||||||
|
|
||||||
```
|
.. code-block:: python
|
||||||
Domain -> [U] (get _batch_uris)
|
|
||||||
U -> R (domain access ; Rs provide batches)
|
Domain -> [U] (get _batch_uris)
|
||||||
R -> [I] (cache here ; _process_batch_data to use load_transform)
|
U -> R (domain access ; Rs provide batches)
|
||||||
[I] -> I (human item obj ; _get_item)
|
R -> [I] (cache here ; _process_batch_data to use load_transform)
|
||||||
I -> **P (final packed item ; __getitem__ to use transform)
|
[I] -> I (human item obj ; _get_item)
|
||||||
```
|
I -> **P (final packed item ; __getitem__ to use transform)
|
||||||
|
|
||||||
Note^1: as far as positioning, this class is meant to play nice with
|
Note^1: as far as positioning, this class is meant to play nice with
|
||||||
PyTorch DataLoaders, hence the inheritance from ``torch.Dataset``. The
|
PyTorch DataLoaders, hence the inheritance from ``torch.Dataset``. The
|
||||||
value add for this over the ``torch.Dataset`` base is almost entirely in
|
value add for this over the ``torch.Dataset`` base is almost entirely
|
||||||
the logic it implements to map out of *batched resources* that are holding
|
in the logic it implements to map out of *batched resources* that are
|
||||||
data, and flattening it out into typical dataset items. There are also some
|
holding data, and flattening it out into typical dataset items. There
|
||||||
QoL items when it comes to splitting and balancing samples.
|
are also some QoL items when it comes to splitting and balancing
|
||||||
|
samples.
|
||||||
|
|
||||||
Note^2: even though ``Domains`` implement iterators over their URIs, this
|
Note^2: even though ``Domains`` implement iterators over their URIs,
|
||||||
doesn't imply a ``BatchedDataset`` is iterable. This just means we can walk
|
this doesn't imply a ``BatchedDataset`` is iterable. This just means we
|
||||||
over the resources that provide data, but we don't necessarily presuppose
|
can walk over the resources that provide data, but we don't necessarily
|
||||||
an ordered walk over samples within batches. Point being:
|
presuppose an ordered walk over samples within batches. Point being:
|
||||||
``torch.Dataset``, not ``torch.IterableDataset``, is the appropriate
|
``torch.Dataset``, not ``torch.IterableDataset``, is the appropriate
|
||||||
superclass, even when we're working around iterable ``Domains``.
|
superclass, even when we're working around iterable ``Domains``.
|
||||||
|
|
||||||
Note^3: transforms are expected to operate on ``I``-items and produce
|
Note^3: transforms are expected to operate on ``I``-items and produce
|
||||||
``I``-items. They shouldn't be the "introducers" of ``I`` types from some
|
``I``-items. They shouldn't be the "introducers" of ``I`` types from
|
||||||
other intermediate representation, nor should they map from ``I`` to
|
some other intermediate representation, nor should they map from ``I``
|
||||||
something else. Point being: the dataset definition should be able to map
|
to something else. Point being: the dataset definition should be able
|
||||||
resources ``R`` to ``I`` without a transform: that much should be baked
|
to map resources ``R`` to ``I`` without a transform: that much should
|
||||||
into the class definition. If you find you're expecting the transform to do
|
be baked into the class definition. If you find you're expecting the
|
||||||
that for you, you should consider pulling in some common structure across
|
transform to do that for you, you should consider pulling in some
|
||||||
the allowed transforms and make it a fixed part of the class.
|
common structure across the allowed transforms and make it a fixed part
|
||||||
|
of the class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from torch import nn, Tensor
|
|||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from trainlib.util.type import OptimizerKwargs
|
from trainlib.utils.type import OptimizerKwargs
|
||||||
|
|
||||||
logger: logging.Logger = logging.getLogger(__name__)
|
logger: logging.Logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from collections.abc import Generator
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn, Tensor
|
from torch import nn, Tensor
|
||||||
from torch.optim import Optimizer
|
|
||||||
from torch.nn.utils.parametrizations import weight_norm
|
from torch.nn.utils.parametrizations import weight_norm
|
||||||
|
|
||||||
logger: logging.Logger = logging.getLogger(__name__)
|
logger: logging.Logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from collections.abc import Callable
|
|||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from torch import cuda, Tensor
|
from torch import cuda, Tensor
|
||||||
|
from torch.optim import Optimizer
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
@@ -98,6 +99,151 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
self._stagnant_epochs = 0
|
self._stagnant_epochs = 0
|
||||||
self._best_model_state_dict: dict[str, Any] = {}
|
self._best_model_state_dict: dict[str, Any] = {}
|
||||||
|
|
||||||
|
def _train_epoch(
|
||||||
|
self,
|
||||||
|
train_loader: DataLoader,
|
||||||
|
batch_estimator_map: Callable[[I, Self], K],
|
||||||
|
optimizers: tuple[Optimizer, ...],
|
||||||
|
writer: SummaryWriter,
|
||||||
|
max_grad_norm: float | None = None,
|
||||||
|
) -> list[float]:
|
||||||
|
"""
|
||||||
|
Train the estimator for a single epoch.
|
||||||
|
"""
|
||||||
|
|
||||||
|
train_loss_sums = []
|
||||||
|
self.estimator.train()
|
||||||
|
with tqdm(train_loader, unit="batch") as train_epoch:
|
||||||
|
for i, batch_data in enumerate(train_epoch):
|
||||||
|
est_kwargs = batch_estimator_map(batch_data, self)
|
||||||
|
inputs = est_kwargs["inputs"]
|
||||||
|
|
||||||
|
# one-time logging
|
||||||
|
if self._step == 0:
|
||||||
|
writer.add_graph(
|
||||||
|
ModelWrapper(self.estimator),
|
||||||
|
est_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# once-per-epoch logging
|
||||||
|
if i == 0:
|
||||||
|
self.estimator.epoch_write(
|
||||||
|
writer,
|
||||||
|
step=self._step,
|
||||||
|
val=False,
|
||||||
|
**est_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
train_losses = self.estimator.loss(**est_kwargs)
|
||||||
|
train_loss_items = []
|
||||||
|
for o_idx, optimizer in enumerate(optimizers):
|
||||||
|
optimizer.zero_grad()
|
||||||
|
train_loss = next(train_losses)
|
||||||
|
|
||||||
|
if len(train_loss_sums) <= o_idx:
|
||||||
|
train_loss_sums.append(0.0)
|
||||||
|
|
||||||
|
train_loss_item = train_loss.item()
|
||||||
|
train_loss_sums[o_idx] += train_loss_item
|
||||||
|
train_loss_items.append(train_loss_item)
|
||||||
|
|
||||||
|
train_loss.backward()
|
||||||
|
|
||||||
|
# clip gradients for optimizer's parameters
|
||||||
|
if max_grad_norm is not None:
|
||||||
|
clip_grad_norm_(
|
||||||
|
self._get_optimizer_parameters(optimizer),
|
||||||
|
max_norm=max_grad_norm
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
self._step += len(inputs)
|
||||||
|
|
||||||
|
for train_loss_item, train_loss_sum in zip(
|
||||||
|
train_loss_items,
|
||||||
|
train_loss_sums,
|
||||||
|
strict=True,
|
||||||
|
):
|
||||||
|
train_epoch.set_postfix(loss=f"{train_loss_sum/(i+1):8.2f}")
|
||||||
|
self._add_summary_item("train_loss", train_loss_item)
|
||||||
|
|
||||||
|
estimator_metrics = self.estimator.metrics(**est_kwargs)
|
||||||
|
for metric_name, metric_value in estimator_metrics.items():
|
||||||
|
self._add_summary_item(
|
||||||
|
f"train_{metric_name}",
|
||||||
|
metric_value
|
||||||
|
)
|
||||||
|
|
||||||
|
self.estimator.epoch_step()
|
||||||
|
|
||||||
|
for li, train_loss_sum in enumerate(train_loss_sums):
|
||||||
|
self._add_summary_item(
|
||||||
|
f"train_loss{li}_epoch", train_loss_sum / len(train_loader)
|
||||||
|
)
|
||||||
|
|
||||||
|
return train_loss_sums
|
||||||
|
|
||||||
|
def _val_epoch(
|
||||||
|
self,
|
||||||
|
val_loader: DataLoader,
|
||||||
|
batch_estimator_map: Callable[[I, Self], K],
|
||||||
|
optimizers: tuple[Optimizer, ...],
|
||||||
|
writer: SummaryWriter,
|
||||||
|
) -> list[float]:
|
||||||
|
"""
|
||||||
|
Perform and record validation scores for a single epoch.
|
||||||
|
"""
|
||||||
|
|
||||||
|
val_loss_sums = []
|
||||||
|
self.estimator.eval()
|
||||||
|
with tqdm(val_loader, unit="batch") as val_epoch:
|
||||||
|
for i, batch_data in enumerate(val_epoch):
|
||||||
|
est_kwargs = batch_estimator_map(batch_data, self)
|
||||||
|
|
||||||
|
# once-per-epoch logging
|
||||||
|
if i == 0:
|
||||||
|
self.estimator.epoch_write(
|
||||||
|
writer,
|
||||||
|
step=self._step,
|
||||||
|
val=True,
|
||||||
|
**est_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
val_losses = self.estimator.loss(**est_kwargs)
|
||||||
|
val_loss_items = []
|
||||||
|
for o_idx in range(len(optimizers)):
|
||||||
|
val_loss = next(val_losses)
|
||||||
|
|
||||||
|
if len(val_loss_sums) <= o_idx:
|
||||||
|
val_loss_sums.append(0.0)
|
||||||
|
|
||||||
|
val_loss_item = val_loss.item()
|
||||||
|
val_loss_sums[o_idx] += val_loss_item
|
||||||
|
val_loss_items.append(val_loss_item)
|
||||||
|
|
||||||
|
for val_loss_item, val_loss_sum in zip(
|
||||||
|
val_loss_items,
|
||||||
|
val_loss_sums,
|
||||||
|
strict=True,
|
||||||
|
):
|
||||||
|
val_epoch.set_postfix(loss=f"{val_loss_sum/(i+1):8.2f}")
|
||||||
|
self._add_summary_item("val_loss", val_loss_item)
|
||||||
|
|
||||||
|
estimator_metrics = self.estimator.metrics(**est_kwargs)
|
||||||
|
for metric_name, metric_value in estimator_metrics.items():
|
||||||
|
self._add_summary_item(f"val_{metric_name}", metric_value)
|
||||||
|
|
||||||
|
for li, val_loss_sum in enumerate(val_loss_sums):
|
||||||
|
self._add_summary_item(
|
||||||
|
f"val_loss{li}_epoch", val_loss_sum / len(val_loader)
|
||||||
|
)
|
||||||
|
|
||||||
|
# convergence of multiple losses may be ambiguous
|
||||||
|
self._val_loss = sum(val_loss_sums) / len(val_loader)
|
||||||
|
|
||||||
|
return val_loss_sums
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
self,
|
self,
|
||||||
dataset: BatchedDataset[..., ..., I],
|
dataset: BatchedDataset[..., ..., I],
|
||||||
@@ -122,40 +268,45 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
"""
|
"""
|
||||||
Note: this method attempts to implement a general scheme for passing
|
Note: this method attempts to implement a general scheme for passing
|
||||||
needed items to the estimator's loss function from the dataloader. The
|
needed items to the estimator's loss function from the dataloader. The
|
||||||
abstract `Estimator` base only requires the model output be provided
|
abstract ``Estimator`` base only requires the model output be provided
|
||||||
for any given loss calculation, but concrete estimators will often
|
for any given loss calculation, but concrete estimators will often
|
||||||
require additional arguments (e.g., labels or length masks, as
|
require additional arguments (e.g., labels or length masks, as is the
|
||||||
is the case with sequential models). In any case, this method defers
|
case with sequential models). In any case, this method defers any
|
||||||
any further logic to the `loss` method of the underlying estimator, so
|
further logic to the ``loss`` method of the underlying estimator, so
|
||||||
one should take care to synchronize the sample structure with `dataset`
|
one should take care to synchronize the sample structure with `dataset`
|
||||||
to match that expected by `self.estimator.loss(...)`.
|
to match that expected by ``self.estimator.loss(...)``.
|
||||||
|
|
||||||
|
.. admonition:: On batch_estimator_map
|
||||||
|
|
||||||
On batch_estimator_map:
|
Dataloader collate functions are responsible for mapping a
|
||||||
|
collection of items into an item of collections, roughly speaking.
|
||||||
|
If items are tuples of tensors,
|
||||||
|
|
||||||
Dataloader collate functions are responsible for mapping a collection
|
.. code-block::
|
||||||
of items into an item of collections, roughly speaking. If items are
|
|
||||||
tuples of tensors,
|
|
||||||
|
|
||||||
[
|
[
|
||||||
( [1, 1], [1, 1] ),
|
( [1, 1], [1, 1] ),
|
||||||
( [2, 2], [2, 2] ),
|
( [2, 2], [2, 2] ),
|
||||||
( [3, 3], [3, 3] ),
|
( [3, 3], [3, 3] ),
|
||||||
]
|
]
|
||||||
|
|
||||||
the collate function maps back into the item skeleton, producing a
|
the collate function maps back into the item skeleton, producing a
|
||||||
single tuple of (stacked) tensors
|
single tuple of (stacked) tensors
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
( [[1, 1],
|
( [[1, 1],
|
||||||
[2, 2],
|
[2, 2],
|
||||||
[3, 3]],
|
[3, 3]],
|
||||||
|
|
||||||
[[1, 1],
|
[[1, 1],
|
||||||
[2, 2],
|
[2, 2],
|
||||||
[3, 3]] )
|
[3, 3]] )
|
||||||
|
|
||||||
This function should map from batches (which should be *item shaped*,
|
This function should map from batches (which should be *item
|
||||||
i.e., have an `I` skeleton, even if stacked items may be different on
|
shaped*, i.e., have an ``I`` skeleton, even if stacked items may be
|
||||||
the inside) into estimator keyword arguments (type `K`).
|
different on the inside) into estimator keyword arguments (type
|
||||||
|
``K``).
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
lr: learning rate (default: 1e-3)
|
lr: learning rate (default: 1e-3)
|
||||||
@@ -212,123 +363,32 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
while self._epoch <= max_epochs and not self._converged(
|
while self._epoch <= max_epochs and not self._converged(
|
||||||
self._epoch, stop_after_epochs
|
self._epoch, stop_after_epochs
|
||||||
):
|
):
|
||||||
print(f"Training epoch {self._epoch}/{max_epochs}...")
|
train_frac = f"{self._epoch}/{max_epochs}"
|
||||||
print(f"Stagnant epochs {self._stagnant_epochs}/{stop_after_epochs}...")
|
stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}"
|
||||||
|
print(f"Training epoch {train_frac}...")
|
||||||
|
print(f"Stagnant epochs {stag_frac}...")
|
||||||
|
|
||||||
epoch_start_time = time.time()
|
epoch_start_time = time.time()
|
||||||
train_loss_sums = []
|
self._train_epoch(
|
||||||
self.estimator.train()
|
train_loader,
|
||||||
with tqdm(train_loader, unit="batch") as train_epoch:
|
batch_estimator_map,
|
||||||
for i, batch_data in enumerate(train_epoch):
|
optimizers,
|
||||||
est_kwargs = batch_estimator_map(batch_data, self)
|
writer,
|
||||||
inputs = est_kwargs["inputs"]
|
max_grad_norm
|
||||||
|
)
|
||||||
# one-time logging
|
|
||||||
if self._step == 0:
|
|
||||||
writer.add_graph(ModelWrapper(self.estimator), est_kwargs)
|
|
||||||
|
|
||||||
# once-per-epoch logging
|
|
||||||
if i == 0:
|
|
||||||
self.estimator.epoch_write(
|
|
||||||
writer,
|
|
||||||
step=self._step,
|
|
||||||
val=False,
|
|
||||||
**est_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
train_losses = self.estimator.loss(**est_kwargs)
|
|
||||||
train_loss_items = []
|
|
||||||
for o_idx, optimizer in enumerate(optimizers):
|
|
||||||
optimizer.zero_grad()
|
|
||||||
train_loss = next(train_losses)
|
|
||||||
|
|
||||||
if len(train_loss_sums) <= o_idx:
|
|
||||||
train_loss_sums.append(0.0)
|
|
||||||
|
|
||||||
train_loss_item = train_loss.item()
|
|
||||||
train_loss_sums[o_idx] += train_loss_item
|
|
||||||
train_loss_items.append(train_loss_item)
|
|
||||||
|
|
||||||
train_loss.backward()
|
|
||||||
|
|
||||||
# clip gradients for optimizer's parameters
|
|
||||||
if max_grad_norm is not None:
|
|
||||||
opt_params = self._get_optimizer_parameters(optimizer)
|
|
||||||
clip_grad_norm_(opt_params, max_norm=max_grad_norm)
|
|
||||||
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
self._step += len(inputs)
|
|
||||||
|
|
||||||
for train_loss_item, train_loss_sum in zip(
|
|
||||||
train_loss_items,
|
|
||||||
train_loss_sums,
|
|
||||||
strict=True,
|
|
||||||
):
|
|
||||||
train_epoch.set_postfix(loss=f"{train_loss_sum/(i+1):8.2f}")
|
|
||||||
self._add_summary_item("train_loss", train_loss_item)
|
|
||||||
|
|
||||||
estimator_metrics = self.estimator.metrics(**est_kwargs)
|
|
||||||
for metric_name, metric_value in estimator_metrics.items():
|
|
||||||
self._add_summary_item(f"train_{metric_name}", metric_value)
|
|
||||||
|
|
||||||
self.estimator.epoch_step()
|
|
||||||
|
|
||||||
for li, train_loss_sum in enumerate(train_loss_sums):
|
|
||||||
self._add_summary_item(
|
|
||||||
f"train_loss{li}_epoch", train_loss_sum / len(train_loader)
|
|
||||||
)
|
|
||||||
|
|
||||||
if val_frac > 0:
|
if val_frac > 0:
|
||||||
val_loss_sums = []
|
self._val_epoch(
|
||||||
self.estimator.eval()
|
val_loader,
|
||||||
with tqdm(val_loader, unit="batch") as val_epoch:
|
batch_estimator_map,
|
||||||
for i, batch_data in enumerate(val_epoch):
|
optimizers,
|
||||||
est_kwargs = batch_estimator_map(batch_data, self)
|
writer,
|
||||||
inputs = est_kwargs["inputs"]
|
)
|
||||||
|
|
||||||
# once-per-epoch logging
|
self._add_summary_item(
|
||||||
if i == 0:
|
"epoch_time_sec",
|
||||||
self.estimator.epoch_write(
|
time.time() - epoch_start_time
|
||||||
writer,
|
)
|
||||||
step=self._step,
|
|
||||||
val=True,
|
|
||||||
**est_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
val_losses = self.estimator.loss(**est_kwargs)
|
|
||||||
val_loss_items = []
|
|
||||||
for o_idx in range(len(optimizers)):
|
|
||||||
val_loss = next(val_losses)
|
|
||||||
|
|
||||||
if len(val_loss_sums) <= o_idx:
|
|
||||||
val_loss_sums.append(0.0)
|
|
||||||
|
|
||||||
val_loss_item = val_loss.item()
|
|
||||||
val_loss_sums[o_idx] += val_loss_item
|
|
||||||
val_loss_items.append(val_loss_item)
|
|
||||||
|
|
||||||
for val_loss_item, val_loss_sum in zip(
|
|
||||||
val_loss_items,
|
|
||||||
val_loss_sums,
|
|
||||||
strict=True,
|
|
||||||
):
|
|
||||||
val_epoch.set_postfix(loss=f"{val_loss_sum/(i+1):8.2f}")
|
|
||||||
self._add_summary_item("val_loss", val_loss_item)
|
|
||||||
|
|
||||||
estimator_metrics = self.estimator.metrics(**est_kwargs)
|
|
||||||
for metric_name, metric_value in estimator_metrics.items():
|
|
||||||
self._add_summary_item(f"val_{metric_name}", metric_value)
|
|
||||||
|
|
||||||
for li, val_loss_sum in enumerate(val_loss_sums):
|
|
||||||
self._add_summary_item(
|
|
||||||
f"val_loss{li}_epoch", val_loss_sum / len(val_loader)
|
|
||||||
)
|
|
||||||
|
|
||||||
# convergence of multiple losses may be ambiguous
|
|
||||||
self._val_loss = sum(val_loss_sums) / len(val_loader)
|
|
||||||
|
|
||||||
self._add_summary_item("epoch_time_sec", time.time() - epoch_start_time)
|
|
||||||
|
|
||||||
if self._epoch % summarize_every == 0:
|
if self._epoch % summarize_every == 0:
|
||||||
self._summarize(writer, self._epoch)
|
self._summarize(writer, self._epoch)
|
||||||
@@ -336,7 +396,9 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
# save checkpoint
|
# save checkpoint
|
||||||
if self._epoch % chkpt_every == 0:
|
if self._epoch % chkpt_every == 0:
|
||||||
self.save_model(
|
self.save_model(
|
||||||
self._epoch, self.chkpt_dir, dir_prefix
|
self._epoch,
|
||||||
|
self.chkpt_dir,
|
||||||
|
dir_prefix
|
||||||
)
|
)
|
||||||
|
|
||||||
self._epoch += 1
|
self._epoch += 1
|
||||||
|
|||||||
Reference in New Issue
Block a user