Compare commits

...

4 Commits

Author SHA1 Message Date
c2e4294c8c update module docs and sphinx config 2026-03-07 19:46:00 -08:00
e867bc0e7f add plot styles, clean up package-wide docstrings 2026-03-07 03:10:13 -08:00
faeef9c72a reformat docstrings for sphinx 2026-03-05 01:36:40 -08:00
805262dfc4 refactor Trainer train/val loop 2026-03-05 01:14:24 -08:00
23 changed files with 1051 additions and 420 deletions

View File

@@ -1,17 +1,12 @@
# Overview # Overview
Package summary goes here, ideally with a diagram Minimal framework for ML modeling, supporting advanced dataset operations and
streamlined training workflows.
# Install # Install
Installation instructions The `trainlib` package can be installed from PyPI:
```sh ```sh
pip install <package> pip install trainlib
```
or as a CLI tool
```sh
uv tool install <package>
``` ```
# Development # Development
@@ -20,16 +15,16 @@ uv tool install <package>
- Depending on needs, install the development dependencies with `uv sync - Depending on needs, install the development dependencies with `uv sync
--extra dev`. --extra dev`.
# Testing ## Testing
- To run the unit tests, make sure to first have the test dependencies - To run the unit tests, make sure to first have the test dependencies
installed with `uv sync --extra test`, then run `make test`. installed with `uv sync --extra test`, then run `make test`.
- For notebook testing, run `make install-kernel` to make the environment - For notebook testing, run `make install-kernel` to make the environment
available as a Jupyter kernel (to be selected when running notebooks). available as a Jupyter kernel (to be selected when running notebooks).
# Documentation ## Documentation
- Install the documentation dependencies with `uv sync --extra doc`. - Install the documentation dependencies with `uv sync --extra doc`.
- Run `make docs-build` (optionally preceded by `make docs-clean`), and serve - Run `make docs-build` (optionally preceded by `make docs-clean`), and serve
locally with `docs-serve`. locally with `make docs-serve`.
# Development remarks # Development remarks
- Across `Trainer` / `Estimator` / `Dataset`, I've considered a - Across `Trainer` / `Estimator` / `Dataset`, I've considered a
@@ -91,7 +86,7 @@ uv tool install <package>
class SequenceDataset[I, **P](HomogenousDataset[int, I, I, P]): class SequenceDataset[I, **P](HomogenousDataset[int, I, I, P]):
... ...
class TupleDataset[I](SequenceDataset[tuple[I, ...], ??]): class TupleDataset[I](SequenceDataset[tuple[I, ...], "?"]):
... ...
``` ```

20
doc/Makefile Normal file
View File

@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

9
doc/_templates/autosummary.md vendored Normal file
View File

@@ -0,0 +1,9 @@
# {{ fullname | escape }}
```{automodule}
{{ fullname }}
:members:
:undoc-members:
:show-inheritance:
:imported-members:
```

8
doc/_templates/autosummary/module.rst vendored Normal file
View File

@@ -0,0 +1,8 @@
{{ fullname | escape | underline}}
.. automodule:: {{ fullname }}
:members:
:undoc-members:
:show-inheritance:
:imported-members:

112
doc/conf.py Normal file
View File

@@ -0,0 +1,112 @@
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Styling: type hints ------------------------------------------------------
# There are several possible style combinations for rendering types, none of
# which are optimal in my view. The main switches are:
#
# - Parameter type hints in the signature vs in the separate parameter list
# - Show type hints as plaintext vs rendered HTML elements
#
# The `sphinx_autodoc_typehints` extension enables more context-aware
# rendering, but it's often way too explicit (e.g., unwrapping type variables)
# and makes things difficult to read. It does, however, allow for automatic
# inclusion of default values, which is nice.
#
# I'd like type hints to be rendered in an inline code element, but that
# doesn't happen by default in either case unless you render them in the
# signature. This is sloppy, however, often just a jumbled mess of parameter
# names and types. The current preferred option is to just use the native
# `autodoc` settings for rendering type hints, leaving them out of the
# signature (for easy heading readability). Type hints in the parameter list
# are also as short as possible, not rendered crazily (by default this is in
# italics; not my favorite but it's what we have). No
# `sphinx_autodoc_typehints` needed at this point; you can toggle it if you
# want automatic default values or different formatting for type hints.
# -- Project information ------------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
project = "trainlib"
copyright = "2026, Sam Griesemer"
author = "Sam Griesemer"
# -- General configuration ----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
extensions = [
"sphinx.ext.autodoc",
# enables a directive to be specified manually that gathers module/object
# summary details in a table
"sphinx.ext.autosummary",
# allow viewing source in the HTML pages
"sphinx.ext.viewcode",
# only really applies to manual docs; docstrings still need RST-like
"myst_parser",
# enables Google-style docstring formats
"sphinx.ext.napoleon",
# external extension that allows arg types to be inferred by type hints;
# without this, type hints show up inside method signatures as plaintext,
# but when enabled they are pulled into the parameter/description block and
# rendered as native nested markup. What's best for a given package may
# vary.
# "sphinx_autodoc_typehints",
]
autosummary_generate = True
autosummary_imported_members = True
# include __init__ definitions in autodoc
autodoc_default_options = {
"special-members": "__init__",
}
templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
# -- Options for autodoc ------------------------------------------------------
# class signatures show up only in __init__ rather than at the class header;
# generally cleaner, avoids redundancy
autodoc_class_signature = "separated"
# if `sphinx_autodoc_typehints` extension is enabled, this is redundant: type
# hints are rendered natively and already show up in the parameter block. If
# it's disabled, this setting will do the same job of moving the types to the
# parameter block, but it renders them in plaintext (with links to in-package
# type refs).
autodoc_typehints = "description" # "signature"
autodoc_typehints_format = "short"
autodoc_preserve_defaults = True
autodoc_use_type_comments = False
python_use_unqualified_type_names = True
# push parameters to their own lines in the signature block
# python_maximum_signature_line_length = 60
# -- Options for autodoc_typehints --------------------------------------------
# always_use_bars_union = True # always on for Python 3.14+
# typehints_defaults = "braces-after" # render defaults in param block
# typehints_use_signature = False # False is default; enable if wanted in sig
# always_document_param_types = True # show types even when not in docstring
# -- Options for HTML output --------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = "furo" # "pydata_sphinx_theme"
html_static_path = ["_static"]
# html_sidebars = {
# '**': ['/modules.html'],
# }

37
doc/index.md Normal file
View File

@@ -0,0 +1,37 @@
# `trainlib` package docs
{ref}`genindex`
{ref}`modindex`
```{eval-rst}
.. autosummary::
:nosignatures:
:recursive:
:caption: Modules
trainlib.dataset
trainlib.domain
trainlib.estimator
trainlib.trainer
trainlib.transform
```
```{toctree}
:maxdepth: 3
:caption: Autoref
:hidden:
_autoref/trainlib.rst
```
```{toctree}
:maxdepth: 3
:caption: Contents
:hidden:
reference/documentation/index
```
```{include} ../README.md
:heading-offset: 1
```

35
doc/make.bat Normal file
View File

@@ -0,0 +1,35 @@
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.https://www.sphinx-doc.org/
exit /b 1
)
if "%1" == "" goto help
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd

View File

@@ -0,0 +1,5 @@
# Documentation
```{toctree}
sphinx
```

View File

@@ -0,0 +1,111 @@
# Sphinx
The primary driver of this package's documentation is Sphinx's `autodoc` extension,
using the [Furo theme][1].
**High-level details**:
- `sphinx-apidoc` generates package-based documentation to the `_autoref/` directory,
with navigation available under "Autoref" in the sidebar.
- Markdown-based documentation files are manually written under the `reference/`
directory, showing up under "Contents" in the sidebar.
## Detailed directory structure
All files are placed under `docs/sphinx`:
- `_`-prefixed are Sphinx-managed directories
* `_build/html/` houses output HTML files
* `_autoref/` is the target for module-based RST files written by `autodoc`
- `reference/`: houses all manually written documentation (totally separate from
auto-generated package docs)
- `conf.py`: single Sphinx configuration file
- `index.md`: documentation index, sets up a persistent sidebar across all other pages
For manually written documentation under `reference/`, topics are nested as needed. Within
a nested directory `reference/<topic>`, an `index.md` should be created with content like:
```
# <Topic>
\`\`\`{toctree}
:hidden:
sub-topic-1.rst
sub-topic-2.rst
...
\`\`\`
```
This will add the nested directory to the sidebar navigation, using the name set under the
top-level header. See [Markdown syntax](#markdown-syntax) for more details on the syntax.
## Sphinx autodoc
Sphinx's `autodoc` extension allows automatic generation of documents according to
(Python) subpackage structure and available docstrings. A few notes here:
- In the `conf.py` file, autodoc is enabled by adding `"sphinx.ext.autodoc"` to
the extensions list. `"sphinx.ext.viewcode"` can also be added to provide
links to source code.
- Documents are actually generated by calling the `sphinx-apidoc` CLI command. The
current Makefile uses the following call:
```sh
sphinx-apidoc --module-first -o docs/sphinx/_autoref/ localsys
```
This writes the automatically generated docs for modules in the package at the
local directory `localsys/` to the `docs/sphinx/_autoref` directory. These are
reStructuredText files by default.
* `--module-first` places the module-level descriptions at the top of the module page.
By default, this is placed at the bottom (oddly), and can be obscured by large lists
of subpackages if this flag isn't provided.
* See available `sphinx-apidoc` options [here][2], as well as more advanced config
[here][3].
## Markdown syntax
The `myst_parser` extension enables Markdown (or something close to it) to be used when
writing documentation files. The Sphinx directives can be difficult to track, and
they change slightly under the MyST Markdown syntax. The following are a few common
blocks:
**Page hierarchies**: the following will generate link hierarchy according to the provided
pages:
```
\`\`\`{toctree}
:maxdepth: <n>
:caption: <caption>
:hidden:
example-file-1
example-file-2
example-dir/index
...
\`\`\`
```
- `:maxdepth:` limits the depth of nesting
- `:caption:` title for the group of pages
- `:hidden:` if provided, links will only show in the sidebar (hidden on the page)
- Constituent files: listed files will be rendered as a link directly. If a listed file
has a `{toctree}` directive, this tree will be rendered in place of the page's link as a
dropdown. The dropdown will be named according to the file's top-level heading, and
clicking directly on the dropdown header will show that page's content. Files found in
the tree will be placed as links under the dropdown, recursively subject to the same rules
described here.
**Include files**: the following will include file content:
```
\`\`\`{include} README.md
\`\`\`
```
**Reference directives**
[1]: https://pradyunsg.me/furo/
[2]: https://www.sphinx-doc.org/en/master/man/sphinx-apidoc.html
[3]: https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#

20
example/dataset.py Normal file
View File

@@ -0,0 +1,20 @@
from typing import NamedTuple
from trainlib.domain import SequenceDomain
from trainlib.datasets.memory import TupleDataset
class Record(NamedTuple):
a: int
b: str
tl_domain = SequenceDomain[Record]([
Record(1, "1"),
Record(2, "2"),
])
class R0(TupleDataset[Record]):
item_tuple = Record
def _process_item_data(self, item_data, item_index):
return (item_data[0],)

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "trainlib" name = "trainlib"
version = "0.1.0" version = "0.1.2"
description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training." description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training."
requires-python = ">=3.13" requires-python = ">=3.13"
authors = [ authors = [
@@ -24,11 +24,11 @@ classifiers = [
"Intended Audience :: End Users/Desktop", "Intended Audience :: End Users/Desktop",
] ]
dependencies = [ dependencies = [
"torch",
"colorama>=0.4.6", "colorama>=0.4.6",
"matplotlib>=3.10.8", "matplotlib>=3.10.8",
"numpy>=2.4.1", "numpy>=2.4.1",
"tensorboard>=2.20.0", "tensorboard>=2.20.0",
"torch>=2.5.1",
"tqdm>=4.67.1", "tqdm>=4.67.1",
] ]
@@ -41,6 +41,7 @@ dev = [
] ]
doc = [ doc = [
"furo", "furo",
# "pydata-sphinx-theme",
"myst-parser", "myst-parser",
"sphinx", "sphinx",
"sphinx-togglebutton", "sphinx-togglebutton",
@@ -82,3 +83,11 @@ force-sort-within-sections = false
quote-style = "double" quote-style = "double"
indent-style = "space" indent-style = "space"
docstring-code-format = true docstring-code-format = true
[tool.uv.sources]
torch = { index = "pytorch" }
[[tool.uv.index]]
name = "pytorch"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

View File

@@ -1,121 +1,62 @@
""" """
Marginalizing out the modality layer: Domain-generic dataset base with attribute-based splitting and balancing
With ``domain`` being an instance variable, one possible interpretation of **Marginalizing out the modality layer**
the object structures here is that one could completely abstract away
the domain model, defining only item structures and processing data. You
could have a single dataset definition for a particular concrete dataset,
and so long as we're talking about the same items, it can be instantiated
using *any domain*. You wouldn't need specific subclasses for disk or
network or in-memory; you can tell it directly at runtime.
That's an eventual possibility, anyway. As it stands, however, this is With ``domain`` being an instance variable, one possible interpretation of
effectively impossible: the object structures here is that one could completely abstract away
the domain model, defining only item structures and processing data. You
could have a single dataset definition for a particular concrete dataset,
and so long as we're talking about the same items, it can be instantiated
using *any domain*. You wouldn't need specific subclasses for disk or
network or in-memory structures; you can tell it directly at runtime.
You can't easily abstract the batch -> item splitting process, i.e., That's an eventual possibility, anyway. As it stands, however, this is
``_process_batch_data()``. A list-based version of the dataset you're effectively impossible:
trying to define might have an individual item tuple at every index,
whereas a disk-based version might have tuples batched across a few files.
This can't reliably be inferred, nor can it be pushed to the
``Domain``-level without needing equal levels of specialization (you'd just
end up needing the exact same structural distinctions in the ``Domain``
hierarchy). So *somewhere* you need a batch splitting implementation that
is both item structure-dependent *and* domain-dependent...the question is
how dynamic you're willing to be about where it comes from. Right now, we
require this actually be defined in the ``_process_batch_data()`` method,
meaning you'll need a specific ``Dataset`` class for each domain you want
to support (e.g., ``MNISTDisk``, ``MNISTList``, ``MNISTNetwork``, etc), or
at least for each domain where "interpreting" a batch could possibly
differ. This is a case where the interface is all that enforces a
distinction: if you've got two domains that can be counted on to yield
batches in the exact same way and can use the same processing, then you
could feasibly provide ``Domain`` objects from either at runtime and have
no issues. We're "structurally blind" to any differentiation beyond the URI
and resource types by design, so two different domain implementations with
the same type signature ``Domain[U, R]`` should be expected to work fine at
runtime (again, so long as they don't also need different batch
processing), but that's not affording us much flexibility, i.e., most of
the time we'll still be defining new dataset classes for each domain.
I initially flagged this as feasible, however, because one could imagine You can't easily abstract the batch-to-item splitting process, i.e.,
accepting a batch processing method upon instantiation rather than ``_process_batch_data()``. A list-based version of the dataset you're trying to
structurally bolting it into the ``Dataset`` definition. This would require define might have an individual item tuple at every index, whereas a disk-based
knowledge of the item structure ``I`` as well as the ``Domain[U, R]``, so version might have tuples batched across a few files. This can't reliably be
such a function will always have to be (I, U, R)-dependent. It nevertheless inferred, nor can it be pushed to the ``Domain``-level without needing equal
would take out some of the pain of having to define new dataset classes; levels of specialization (you'd just end up needing the exact same structural
instead, you'd just need to define the batch processing method. I see this distinctions in the ``Domain`` hierarchy). So *somewhere* you need a batch
as a worse alternative to just defining *inside* a safe context like a new splitting implementation that is both item structure-dependent *and*
dataset class: you know the types you have to respect, and you stick that domain-dependent...the question is how dynamic you're willing to be about where
method exactly in a context where it's understood. Freeing this up doesn't it comes from. Right now, we require this actually be defined in the
lighten the burden of processing logic, it just changes *when* it has to be ``_process_batch_data()`` method, meaning you'll need a specific ``Dataset``
provided, and that's not worth much (to me) in this case given the bump in class for each domain you want to support (e.g., ``MNISTDisk``, ``MNISTList``,
complexity. (Taking this to the extreme: you could supply *all* of an ``MNISTNetwork``, etc), or at least for each domain where "interpreting" a
object's methods "dynamically" and glue them together at runtime so long as batch could possibly differ. This is a case where the interface is all that
they all played nice. But wherever you were "laying them out" beforehand is enforces a distinction: if you've got two domains that can be counted on to
exactly the job of a class to begin with, so you don't end up with anything yield batches in the exact same way and can use the same processing, then you
more dynamic. All we're really discussing here is pushing around could feasibly provide ``Domain`` objects from either at runtime and have no
unavoidable complexity inside and outside of the "class walls," and in the issues. We're "structurally blind" to any differentiation beyond the URI and
particular case of ``_process_batch_data()``, it feels much better when resource types by design, so two different domain implementations with the same
it's on the inside.) type signature ``Domain[U, R]`` should be expected to work fine at runtime
(again, so long as they don't also need different batch processing), but that's
not affording us much flexibility, i.e., most of the time we'll still be
defining new dataset classes for each domain.
Holding: I initially flagged this as feasible, however, because one could imagine
@abstractmethod accepting a batch processing method upon instantiation rather than structurally
def _get_uri_groups(self) -> Iterable[tuple[U, ...]]: bolting it into the ``Dataset`` definition. This would require knowledge of the
Get URI groups for each batch. item structure ``I`` as well as the ``Domain[U, R]``, so such a function will
always have to be ``(I, U, R)``-dependent. It nevertheless would take out some
If there's more than one URI per batch (e.g., a data file and a of the pain of having to define new dataset classes; instead, you'd just need
metadata file), zip the URIs such that we have a tuple of URIs per to define the batch processing method. I see this as a worse alternative to
batch. just defining *inside* a safe context like a new dataset class: you know the
types you have to respect, and you stick that method exactly in a context where
Note that this effectively defines the index style over batches in the it's understood. Freeing this up doesn't lighten the burden of processing
attached domain. We get an ``int -> tuple[U, ...]`` map that turns logic, it just changes *when* it has to be provided, and that's not worth much
batch indices into URIs that can be read under the domain. (to me) in this case given the bump in complexity. (Taking this to the extreme:
``get_batch()`` turns an integer index into its corresponding you could supply *all* of an object's methods "dynamically" and glue them
``tuple[U, ...]``, reading the resources with ``_read_resources()`` in together at runtime so long as they all played nice. But wherever you were
the tuple, treating them as providers of batched data. "laying them out" beforehand is exactly the job of a class to begin with, so
``_read_resources()`` passes through to the attached domain logic, you don't end up with anything more dynamic. All we're really discussing here
which, although common, need not supply an explicit iterable of batch is pushing around unavoidable complexity inside and outside of the "class
items: we just access items with ``__getitem__()`` and may ask for walls," and in the particular case of ``_process_batch_data()``, it feels much
``__len__``. So the returned URI group collection (this method) does better when it's on the inside.)
need to be iterable to measure the number of batches, but the batch
objects that are ultimately produced by these URI groups need not be
iterables themselves.
raise NotImplementedError
def _read_resources(
self,
uri_group: tuple[U, ...],
batch_index: int
) -> tuple[R, ...]:
Read batch files at the provided paths.
This method should operate on a single tuple from the list of batch
tuples returned by the ``_get_uri_groups()`` method. That is, it reads
all of the resources for a single batch and returns a tuple of the same
size with their contents.
Note: the dependence on a batch index is mostly here to make
multi-dataset composition easier later. In-dataset, you don't need to
know the batch index to simply process URIs, but across datasets you
need it to find out the origin of the batch (and process those URIs
accordingly).
return tuple(self.domain.read(uri) for uri in uri_group)
# pulling the type variable out of the inline generic b/c `ty` has trouble
# understanding bound type variables in subclasses (specifically with Self@)
T = TypeVar("T", bound=NamedTuple)
class NamedTupleDataset[I](Dataset):
def __init__(self, data_list: list[I]) -> None:
self.data_list = data_list
def __len__(self) -> int:
return len(self.data_list)
def __getitem__(self, index: int) -> I:
return self.data_list[index]
""" """
import math import math
@@ -154,40 +95,77 @@ class BatchedDataset[U, R, I](Dataset):
The class is generic over a URI type ``U``, a resource type ``R`` (both of The class is generic over a URI type ``U``, a resource type ``R`` (both of
which are used to concretize a domain ``Domain[U, R]``), and an item type which are used to concretize a domain ``Domain[U, R]``), and an item type
``T`` (which has a ``tuple`` upper bound). ``I``.
Pipeline overview: **Batch and item processing flow**
``` .. code-block:: text
Domain -> [U] (get _batch_uris)
U -> R (domain access ; Rs provide batches)
R -> [I] (cache here ; _process_batch_data to use load_transform)
[I] -> I (human item obj ; _get_item)
I -> **P (final packed item ; __getitem__ to use transform)
```
Note^1: as far as positioning, this class is meant to play nice with Domain -> [U] :: self._batch_uris = list(domain)
PyTorch DataLoaders, hence the inheritance from ``torch.Dataset``. The
value add for this over the ``torch.Dataset`` base is almost entirely in
the logic it implements to map out of *batched resources* that are holding
data, and flattening it out into typical dataset items. There are also some
QoL items when it comes to splitting and balancing samples.
Note^2: even though ``Domains`` implement iterators over their URIs, this Grab all URIs from Domain iterators. This is made concrete early to
doesn't imply a ``BatchedDataset`` is iterable. This just means we can walk allow for Dataset sizing, and we need a Sequence representation to
over the resources that provide data, but we don't necessarily presuppose map integer batch indices into Domains, i.e., when getting the
an ordered walk over samples within batches. Point being: corresponding URI:
``torch.Dataset``, not ``torch.IterableDataset``, is the appropriate
superclass, even when we're working around iterable ``Domains``.
Note^3: transforms are expected to operate on ``I``-items and produce batch_uri = self._batch_uris[batch_index]
``I``-items. They shouldn't be the "introducers" of ``I`` types from some
other intermediate representation, nor should they map from ``I`` to We let Domains implement iterators over their URIs, but explicitly
something else. Point being: the dataset definition should be able to map exhaust when initializing Datasets.
resources ``R`` to ``I`` without a transform: that much should be baked
into the class definition. If you find you're expecting the transform to do U -> R :: batch_data = self.domain[batch_uri]
that for you, you should consider pulling in some common structure across
the allowed transforms and make it a fixed part of the class. Retrieve resource from domain. Resources are viewed as batched
data, even if only wrapping single items (happens in trivial
settings).
R -> [I] :: self._process_batch_data(batch_data, batch_index)
Possibly domain-specific batch processing of resource data into
explicit Sequence-like structures of items, each of which is
subject to the provided pre_transform. Processed batches at this
stage are cached (if enabled).
[I] -> I :: self.get_batch(batch_index)[index_in_batch]
Select individual items from batches in _get_item. At this stage,
items are in intermediate states and pulled from the cached
batches.
I -> I :: self._process_item_data(item_data, index)
Produce final items with __getitem__, getting intermediate items
via _get_item and applying the provided post_transform.
.. note::
As far as positioning, this class is meant to play nice with PyTorch
DataLoaders, hence the inheritance from ``torch.Dataset``. The value
add for this over the ``torch.Dataset`` base is almost entirely in the
logic it implements to map out of *batched resources* that are holding
data, and flattening it out into typical dataset items. There are also
some QoL features when it comes to splitting and balancing samples.
.. note::
Even though ``Domains`` implement iterators over their URIs, this
doesn't imply a ``BatchedDataset`` is iterable. This just means we can
walk over the resources that provide data, but we don't necessarily
presuppose an ordered walk over samples within batches. Point being:
``torch.Dataset``, not ``torch.IterableDataset``, is the appropriate
superclass, even when we're working around iterable ``Domains``.
.. note::
Transforms are expected to operate on ``I``-items and produce
``I``-items. They shouldn't be the "introducers" of ``I`` types from
some other intermediate representation, nor should they map from ``I``
to something else. Point being: the dataset definition should be able
to map resources ``R`` to ``I`` without a transform: that much should
be baked into the class definition. If you find you're expecting the
transform to do that for you, you should consider pulling in some
common structure across the allowed transforms and make it a fixed part
of the class.
""" """
def __init__( def __init__(
@@ -201,6 +179,7 @@ class BatchedDataset[U, R, I](Dataset):
) -> None: ) -> None:
""" """
Parameters: Parameters:
domain: ``Domain`` object providing access to batched data
pre_transform: transform to apply over items during loading (in pre_transform: transform to apply over items during loading (in
``_process_batch_data()``), i.e., *before* going into ``_process_batch_data()``), i.e., *before* going into
persistent storage persistent storage
@@ -210,6 +189,7 @@ class BatchedDataset[U, R, I](Dataset):
batch_cache_limit: the max number of batches to cache at any batch_cache_limit: the max number of batches to cache at any
one time one time
preload: whether to load all data into memory during instantiation preload: whether to load all data into memory during instantiation
num_workers: number of workers to use when preloading data
""" """
self.domain = domain self.domain = domain
@@ -249,6 +229,9 @@ class BatchedDataset[U, R, I](Dataset):
The behavior of this method can vary depending on what we know about The behavior of this method can vary depending on what we know about
batch sizes, and should therefore be implemented by inheriting classes. batch sizes, and should therefore be implemented by inheriting classes.
Parameters:
item_index: index of item
Returns: Returns:
batch_index: int batch_index: int
index_in_batch: int index_in_batch: int
@@ -292,6 +275,10 @@ class BatchedDataset[U, R, I](Dataset):
place to use a provided ``post_transform``; items are pulled from the place to use a provided ``post_transform``; items are pulled from the
cache (if enabled) and processed before being returned as the final cache (if enabled) and processed before being returned as the final
tuple outputs (so this processing is not persistent). tuple outputs (so this processing is not persistent).
Parameters:
item_data: item data
item_index: index of item
""" """
raise NotImplementedError raise NotImplementedError
@@ -307,6 +294,9 @@ class BatchedDataset[U, R, I](Dataset):
Note that return values from `__getitem__()` are "cleaned up" versions Note that return values from `__getitem__()` are "cleaned up" versions
of this representation, with minimal info needed for training. of this representation, with minimal info needed for training.
Parameters:
item_index: index of item
""" """
if item_index >= len(self): if item_index >= len(self):
@@ -341,10 +331,13 @@ class BatchedDataset[U, R, I](Dataset):
they're always connected, and nothing would notice if you waited they're always connected, and nothing would notice if you waited
between steps. The only way this could matter is if you split the between steps. The only way this could matter is if you split the
resource reading and batch processing steps across methods, but when it resource reading and batch processing steps across methods, but when it
actually comes to accessing/caching the batch, you'd have to expand actually comes to accessing/caching the batch, you'd have to expand any
any delayed reads here. There's no way around needing to see all batch delayed reads here. There's no way around needing to see all batch data
data at once here, and we don't want to make that ambiguous: ``list`` at once here, and we don't want to make that ambiguous: ``list`` output
output type it is. type it is.
Parameters:
batch_index: index of batch
""" """
logger.debug("Batch cache miss, reading from root...") logger.debug("Batch cache miss, reading from root...")
@@ -364,6 +357,9 @@ class BatchedDataset[U, R, I](Dataset):
Can be useful when dynamically pulling data (as it's requested) isn't Can be useful when dynamically pulling data (as it's requested) isn't
desired. Requires that `cache_sample_limit=None`, i.e., the cache won't desired. Requires that `cache_sample_limit=None`, i.e., the cache won't
continually remove previous batches as they're loaded. continually remove previous batches as they're loaded.
Parameters:
num_workers: number of parallel workers to use for data loading
""" """
assert self.batch_cache_limit is None, "Preloading under cache limit" assert self.batch_cache_limit is None, "Preloading under cache limit"
@@ -396,36 +392,46 @@ class BatchedDataset[U, R, I](Dataset):
""" """
Split dataset into fractional pieces by data attribute. Split dataset into fractional pieces by data attribute.
If `by_attr` is None, recovers typical fractional splitting of dataset If ``by_attr`` is None, recovers typical fractional splitting of
items, partitioning by size. Using None anywhere will index each item dataset items, partitioning by size. Using None anywhere will index
into its own bucket, i.e., by its index. For instance, each item into its own bucket, i.e., by its index. For instance:
- by_attr=["color"] -> {("red", 1), ("red", 2)}, - Splits on the attribute such that each subset contains entire strata
{("blue", 1), ("blue", 2)} of the attribute. "Homogeneity within clusters:"
Splits on the attribute such that each subset contains entire strata .. code-block::
of the attribute. "Homogeneity within clusters"
- `by_attr=["color", None]` -> {("red", 1), ("blue", 1)}, by_attr=["color"] -> {("red", 1), ("red", 2)},
{("red", 2), ("blue", 2)} {("blue", 1), ("blue", 2)}
Stratifies by attribute and then splits "by index" within, uniformly - Stratifies by attribute and then splits "by index" within, uniformly
grabbing samples across strata to form new clusters. "Homogeneity grabbing samples across strata to form new clusters. "Homogeneity
across clusters" across clusters"
.. code-block::
by_attr=["color", None] -> {("red", 1), ("blue", 1)},
{("red", 2), ("blue", 2)}
Note that the final list of Subsets returned are built from shallow Note that the final list of Subsets returned are built from shallow
copies of the underlying dataset (i.e., `self`) to allow manual copies of the underlying dataset (i.e., ``self``) to allow manual
intervention with dataset attributes (e.g., setting the splits to have intervention with dataset attributes (e.g., setting the splits to have
different `transform`s). This is subject to possibly unexpected different ``transforms``). This is subject to possibly unexpected
behavior if re-caching data or you need a true copy of all data in behavior if re-caching data or you need a true copy of all data in
memory, but should otherwise leave most interactions unchanged. memory, but should otherwise leave most interactions unchanged.
Parameters: Parameters:
frac: split fractions for datasets
dataset: dataset to split, defaults to ``self``. Facilitates
recursive splitting when multi-attribute splits are needed.
by_attr: attribute or attributes to use when grouping strata for
dataset splits. Defaults to ``None``, which will use item
indices.
shuffle_strata: shuffle the strata order before split is drawn. We shuffle_strata: shuffle the strata order before split is drawn. We
parameterize this because a dataloader-level shuffle operation parameterize this because a Dataloader-level shuffle operation
will only change the order of the indices in the resulting will only change the order of the indices in the resulting
splits; only a shuffle of items inside the strata can change splits; only a shuffle of the strata order can change the
the actual content of the splits themselves. actual content of the splits themselves.
""" """
if by_attr == []: if by_attr == []:
@@ -534,6 +540,32 @@ class BatchedDataset[U, R, I](Dataset):
split_max_sizes: list[int] | None = None, split_max_sizes: list[int] | None = None,
shuffle_strata: bool = True, shuffle_strata: bool = True,
) -> None: ) -> None:
"""
Balance the distribution of provided attributes over dataset items.
This method sets the indices over the dataset according to the result
of the rebalancing. The indices are produced by the recursive
``_balance()`` method, which is necessarily separate due to the need
for a contained recursive approach that doesn't change the underlying
dataset during execution.
Parameters:
dataset: dataset to split, defaults to ``self``. Facilitates
recursive splitting when multi-attribute splits are needed.
by_attr: attribute or attributes to use when grouping strata for
dataset splits. Defaults to ``None``, which will use item
indices.
split_min_sizes: minimum allowed sizes of splits. Must have the
same length as ``by_attr``.
split_max_sizes: maximum allowed sizes of splits. Must have the
same length as ``by_attr``.
shuffle_strata: shuffle the strata order before split is drawn. We
parameterize this because a Dataloader-level shuffle operation
will only change the order of the indices in the resulting
splits; only a shuffle of the strata order can change the
actual content of the splits themselves.
"""
self.indices = self._balance( self.indices = self._balance(
dataset, dataset,
by_attr, by_attr,
@@ -551,9 +583,29 @@ class BatchedDataset[U, R, I](Dataset):
shuffle_strata: bool = True, shuffle_strata: bool = True,
) -> list[int]: ) -> list[int]:
""" """
Note: behavior is a little odd for nested behavior; not exactly Recursive balancing of items by attribute.
perfectly uniform throughout. This is a little difficult: you can't
exactly know ahead of time the size of the subgroups across splits .. note::
Behavior is a little odd for nested behavior; not exactly perfectly
uniform throughout. This is a little difficult: you can't exactly
know ahead of time the size of the subgroups across splits
Parameters:
dataset: dataset to split, defaults to ``self``. Facilitates
recursive splitting when multi-attribute splits are needed.
by_attr: attribute or attributes to use when grouping strata for
dataset splits. Defaults to ``None``, which will use item
indices.
split_min_sizes: minimum allowed sizes of splits. Must have the
same length as ``by_attr``.
split_max_sizes: maximum allowed sizes of splits. Must have the
same length as ``by_attr``.
shuffle_strata: shuffle the strata order before split is drawn. We
parameterize this because a Dataloader-level shuffle operation
will only change the order of the indices in the resulting
splits; only a shuffle of the strata order can change the
actual content of the splits themselves.
""" """
if by_attr == []: if by_attr == []:
@@ -643,6 +695,9 @@ class BatchedDataset[U, R, I](Dataset):
dataset. The underlying data remain the same, but when indices get set, dataset. The underlying data remain the same, but when indices get set,
you're effectively applying a mask over any existing indices, always you're effectively applying a mask over any existing indices, always
operating *relative* to the existing mask. operating *relative* to the existing mask.
Parameters:
indices: list of indices to set
""" """
# manually set new size # manually set new size
@@ -670,6 +725,13 @@ class BatchedDataset[U, R, I](Dataset):
return self._dataset_len return self._dataset_len
def __getitem__(self, index: int) -> I: def __getitem__(self, index: int) -> I:
"""
Get the dataset item at the specified index.
Parameters:
index: index of item to retrieve
"""
item_data = self._get_item(index) item_data = self._get_item(index)
index = self.indices[index] index = self.indices[index]
@@ -691,9 +753,10 @@ class CompositeBatchedDataset[U, R, I](BatchedDataset[U, R, I]):
""" """
Dataset class for wrapping individual datasets. Dataset class for wrapping individual datasets.
Note: because this remains a valid ``BatchedDataset``, we re-thread the .. note::
generic type variables through the set of composed datasets. That is, they Because this remains a valid ``BatchedDataset``, we re-thread the
must have a common domain type ``Domain[U, R]``. generic type variables through the set of composed datasets. That is,
they must have a common domain type ``Domain[U, R]``.
""" """
def __init__( def __init__(
@@ -878,7 +941,7 @@ class HomogenousDataset[U, R, I](BatchedDataset[U, R, I]):
class HeterogenousDataset[U, R, I](BatchedDataset[U, R, I]): class HeterogenousDataset[U, R, I](BatchedDataset[U, R, I]):
""" """
Batched dataset where batches have arbitrary size. Batched dataset where batches may have arbitrary size.
Methods left for inheriting classes: Methods left for inheriting classes:

View File

@@ -12,39 +12,38 @@ class DiskDataset[T: NamedTuple](HomogenousDataset[Path, bytes, T]):
""" """
The following line is to satisfy the type checker, which The following line is to satisfy the type checker, which
1. Can't recognize an appropriately re-typed constructor arg like 1. Can't recognize an appropriately re-typed constructor arg like::
def __init__(
self,
domain: DiskDomain,
...
): ...
This *does* match the parent generic for the U=Path, R=bytes context def __init__(
self,
domain: DiskDomain,
...
): ...
def __init__( This *does* match the parent generic for the ``U=Path``, ``R=bytes``
self, context::
domain: Domain[U, R],
... def __init__(
): ... self,
domain: Domain[U, R],
...
): ...
but the type checker doesn't see this. but the type checker doesn't see this.
2. "Lifted" type variables out of generics can't be used as upper bounds, 2. "Lifted" type variables out of generics can't be used as upper bounds,
at least not without throwing type checker warnings (thanks to PEP695). at least not without throwing type checker warnings (thanks to PEP695).
So I'm not allowed to have So I'm not allowed to have::
``` class BatchedDataset[U, R, D: Domain[U, R]]:
class BatchedDataset[U, R, D: Domain[U, R]]: ...
...
```
which could bring appropriately dynamic typing for ``Domain``s, but is which could bring appropriately dynamic typing for ``Domains``, but is
not a sufficiently concrete upper bound. not a sufficiently concrete upper bound.
So: we settle for a class-level type declaration, which despite not being So: we settle for a class-level type declaration, which despite not being
technically appropriately scoped, it's not harming anything and satisfies technically appropriately scoped, it's not harming anything and satisfies
``ty`` type checks downstream (e.g., when we access ``DiskDomain.root``. ``ty`` type checks downstream (e.g., when we access ``DiskDomain.root``).
""" """
domain: DiskDomain domain: DiskDomain

View File

@@ -1,24 +1,5 @@
""" """
Defines a knowledge domain. Wraps a Dataset / Simulator / Knowledge Generic URI-resource mapping structure
Downstream exploration might include
- Calibrating Simulator / Knowledge with a Dataset
- Amending Dataset with Simulator / Knowledge
- Positioning Knowledge within Simulator context
* Where to replace Simulator subsystem with Knowledge?
Other variations:
- Multi-fidelity simulators
- Multi-scale models
- Multi-system
- Incomplete knowledge / divergence among sources
Questions:
- Should Simulator / Knowledge be unified as one (e.g., "Expert")
""" """
from collections.abc import Mapping, Iterator, Sequence from collections.abc import Mapping, Iterator, Sequence

View File

@@ -1,4 +1,6 @@
""" """
Base class for trainable models
Development note Development note
I'd rather lay out bare args and kwargs in the estimator methods, but the I'd rather lay out bare args and kwargs in the estimator methods, but the
@@ -24,7 +26,7 @@ from torch import nn, Tensor
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from trainlib.util.type import OptimizerKwargs from trainlib.utils.type import OptimizerKwargs
logger: logging.Logger = logging.getLogger(__name__) logger: logging.Logger = logging.getLogger(__name__)

View File

@@ -107,8 +107,14 @@ class LSTM[K: RNNKwargs](Estimator[K]):
with torch.no_grad(): with torch.no_grad():
loss = next(self.loss(**kwargs)).item() loss = next(self.loss(**kwargs)).item()
predictions = self(**kwargs)[0]
labels = kwargs["labels"]
mae = F.l1_loss(predictions, labels).item()
return { return {
"loss": loss, "loss": loss,
"mse": loss,
"mae": mae,
"grad_norm": get_grad_norm(self) "grad_norm": get_grad_norm(self)
} }
@@ -291,7 +297,7 @@ class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]):
logger.info(f"| > {self.output_dim=}") logger.info(f"| > {self.output_dim=}")
class ConvRNN[K: RNNKwargs](Estimator[K]): class ConvGRU[K: RNNKwargs](Estimator[K]):
""" """
Base recurrent convolutional architecture. Base recurrent convolutional architecture.
@@ -441,11 +447,18 @@ class ConvRNN[K: RNNKwargs](Estimator[K]):
with torch.no_grad(): with torch.no_grad():
loss = next(self.loss(**kwargs)).item() loss = next(self.loss(**kwargs)).item()
predictions = self(**kwargs)[0].squeeze(-1)
labels = kwargs["labels"]
mae = F.l1_loss(predictions, labels).item()
return { return {
"loss": loss, "loss": loss,
"mse": loss,
"mae": mae,
"grad_norm": get_grad_norm(self) "grad_norm": get_grad_norm(self)
} }
def optimizers( def optimizers(
self, self,
**kwargs: Unpack[OptimizerKwargs], **kwargs: Unpack[OptimizerKwargs],

View File

@@ -1,10 +1,7 @@
import logging import logging
from collections.abc import Generator
import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn, Tensor from torch import nn, Tensor
from torch.optim import Optimizer
from torch.nn.utils.parametrizations import weight_norm from torch.nn.utils.parametrizations import weight_norm
logger: logging.Logger = logging.getLogger(__name__) logger: logging.Logger = logging.getLogger(__name__)

View File

@@ -1,3 +1,7 @@
"""
Core interface for training ``Estimators`` with ``Datasets``
"""
import os import os
import time import time
import logging import logging
@@ -11,6 +15,7 @@ from collections.abc import Callable
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from torch import cuda, Tensor from torch import cuda, Tensor
from torch.optim import Optimizer
from torch.nn.utils import clip_grad_norm_ from torch.nn.utils import clip_grad_norm_
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@@ -30,7 +35,15 @@ logger: logging.Logger = logging.getLogger(__name__)
class Trainer[I, K: EstimatorKwargs]: class Trainer[I, K: EstimatorKwargs]:
""" """
Training interface for updating ``Estimators`` with ``Datasets``. Training interface for optimizing parameters of ``Estimators`` with
``Datasets``.
This class is generic to a dataset item type ``I`` and an estimator kwarg
type ``K``. These are the two primary components ``Trainer`` objects need
to coordinate: they ultimately rely on a provided map to ensure data items
(type ``I``) from a dataset are appropriately routed as inputs to key
estimator methods (like ``forward()`` and ``loss()``), which accept inputs
of type ``K``.
""" """
def __init__( def __init__(
@@ -42,8 +55,10 @@ class Trainer[I, K: EstimatorKwargs]:
) -> None: ) -> None:
""" """
Parameters: Parameters:
estimator: `Estimator` model object estimator: ``Estimator`` model object
device: device on which to carry out training device: device on which to carry out training
chkpt_dir: directory to write model checkpoints
tblog_dir: directory to write TensorBoard logs
""" """
self.device: str self.device: str
@@ -86,7 +101,7 @@ class Trainer[I, K: EstimatorKwargs]:
def reset(self) -> None: def reset(self) -> None:
""" """
Set base tracking parameters. Set initial tracking parameters for the primary training loop.
""" """
self._step: int = 0 self._step: int = 0
@@ -98,6 +113,151 @@ class Trainer[I, K: EstimatorKwargs]:
self._stagnant_epochs = 0 self._stagnant_epochs = 0
self._best_model_state_dict: dict[str, Any] = {} self._best_model_state_dict: dict[str, Any] = {}
def _train_epoch(
self,
train_loader: DataLoader,
batch_estimator_map: Callable[[I, Self], K],
optimizers: tuple[Optimizer, ...],
writer: SummaryWriter,
max_grad_norm: float | None = None,
) -> list[float]:
"""
Train the estimator for a single epoch.
"""
train_loss_sums = []
self.estimator.train()
with tqdm(train_loader, unit="batch") as train_epoch:
for i, batch_data in enumerate(train_epoch):
est_kwargs = batch_estimator_map(batch_data, self)
inputs = est_kwargs["inputs"]
# one-time logging
if self._step == 0:
writer.add_graph(
ModelWrapper(self.estimator),
est_kwargs
)
# once-per-epoch logging
if i == 0:
self.estimator.epoch_write(
writer,
step=self._step,
val=False,
**est_kwargs
)
train_losses = self.estimator.loss(**est_kwargs)
train_loss_items = []
for o_idx, optimizer in enumerate(optimizers):
optimizer.zero_grad()
train_loss = next(train_losses)
if len(train_loss_sums) <= o_idx:
train_loss_sums.append(0.0)
train_loss_item = train_loss.item()
train_loss_sums[o_idx] += train_loss_item
train_loss_items.append(train_loss_item)
train_loss.backward()
# clip gradients for optimizer's parameters
if max_grad_norm is not None:
clip_grad_norm_(
self._get_optimizer_parameters(optimizer),
max_norm=max_grad_norm
)
optimizer.step()
self._step += len(inputs)
for train_loss_item, train_loss_sum in zip(
train_loss_items,
train_loss_sums,
strict=True,
):
train_epoch.set_postfix(loss=f"{train_loss_sum/(i+1):8.2f}")
self._add_summary_item("train_loss", train_loss_item)
estimator_metrics = self.estimator.metrics(**est_kwargs)
for metric_name, metric_value in estimator_metrics.items():
self._add_summary_item(
f"train_{metric_name}",
metric_value
)
self.estimator.epoch_step()
for li, train_loss_sum in enumerate(train_loss_sums):
self._add_summary_item(
f"train_loss{li}_epoch", train_loss_sum / len(train_loader)
)
return train_loss_sums
def _val_epoch(
self,
val_loader: DataLoader,
batch_estimator_map: Callable[[I, Self], K],
optimizers: tuple[Optimizer, ...],
writer: SummaryWriter,
) -> list[float]:
"""
Perform and record validation scores for a single epoch.
"""
val_loss_sums = []
self.estimator.eval()
with tqdm(val_loader, unit="batch") as val_epoch:
for i, batch_data in enumerate(val_epoch):
est_kwargs = batch_estimator_map(batch_data, self)
# once-per-epoch logging
if i == 0:
self.estimator.epoch_write(
writer,
step=self._step,
val=True,
**est_kwargs
)
val_losses = self.estimator.loss(**est_kwargs)
val_loss_items = []
for o_idx in range(len(optimizers)):
val_loss = next(val_losses)
if len(val_loss_sums) <= o_idx:
val_loss_sums.append(0.0)
val_loss_item = val_loss.item()
val_loss_sums[o_idx] += val_loss_item
val_loss_items.append(val_loss_item)
for val_loss_item, val_loss_sum in zip(
val_loss_items,
val_loss_sums,
strict=True,
):
val_epoch.set_postfix(loss=f"{val_loss_sum/(i+1):8.2f}")
self._add_summary_item("val_loss", val_loss_item)
estimator_metrics = self.estimator.metrics(**est_kwargs)
for metric_name, metric_value in estimator_metrics.items():
self._add_summary_item(f"val_{metric_name}", metric_value)
for li, val_loss_sum in enumerate(val_loss_sums):
self._add_summary_item(
f"val_loss{li}_epoch", val_loss_sum / len(val_loader)
)
# convergence of multiple losses may be ambiguous
self._val_loss = sum(val_loss_sums) / len(val_loader)
return val_loss_sums
def train( def train(
self, self,
dataset: BatchedDataset[..., ..., I], dataset: BatchedDataset[..., ..., I],
@@ -122,44 +282,54 @@ class Trainer[I, K: EstimatorKwargs]:
""" """
Note: this method attempts to implement a general scheme for passing Note: this method attempts to implement a general scheme for passing
needed items to the estimator's loss function from the dataloader. The needed items to the estimator's loss function from the dataloader. The
abstract `Estimator` base only requires the model output be provided abstract ``Estimator`` base only requires the model output be provided
for any given loss calculation, but concrete estimators will often for any given loss calculation, but concrete estimators will often
require additional arguments (e.g., labels or length masks, as require additional arguments (e.g., labels or length masks, as is the
is the case with sequential models). In any case, this method defers case with sequential models). In any case, this method defers any
any further logic to the `loss` method of the underlying estimator, so further logic to the ``loss`` method of the underlying estimator, so
one should take care to synchronize the sample structure with `dataset` one should take care to synchronize the sample structure with `dataset`
to match that expected by `self.estimator.loss(...)`. to match that expected by ``self.estimator.loss(...)``.
.. admonition:: On ``batch_estimator_map``
On batch_estimator_map: Dataloader collate functions are responsible for mapping a
collection of items into an item of collections, roughly speaking.
If items are tuples of tensors,
Dataloader collate functions are responsible for mapping a collection .. code-block:: text
of items into an item of collections, roughly speaking. If items are
tuples of tensors,
[ [
( [1, 1], [1, 1] ), ( [1, 1], [1, 1] ),
( [2, 2], [2, 2] ), ( [2, 2], [2, 2] ),
( [3, 3], [3, 3] ), ( [3, 3], [3, 3] ),
] ]
the collate function maps back into the item skeleton, producing a the collate function maps back into the item skeleton, producing a
single tuple of (stacked) tensors single tuple of (stacked) tensors
.. code-block:: text
( [[1, 1], ( [[1, 1],
[2, 2], [2, 2],
[3, 3]], [3, 3]],
[[1, 1], [[1, 1],
[2, 2], [2, 2],
[3, 3]] ) [3, 3]] )
This function should map from batches (which should be *item shaped*, This function should map from batches (which should be *item
i.e., have an `I` skeleton, even if stacked items may be different on shaped*, i.e., have an ``I`` skeleton, even if stacked items may be
the inside) into estimator keyword arguments (type `K`). different on the inside) into estimator keyword arguments (type
``K``).
Parameters: Parameters:
dataset: dataset to train the estimator
batch_estimator_map: function mapping from batch data to expected
estimator kwargs
lr: learning rate (default: 1e-3) lr: learning rate (default: 1e-3)
eps: adam EPS (default: 1e-8) eps: adam EPS (default: 1e-8)
max_grad_norm: upper bound to use when clipping gradients. If left
as ``None``, no gradient clipping is performed.
max_epochs: maximum number of training epochs max_epochs: maximum number of training epochs
stop_after_epochs: number of epochs with stagnant validation losses stop_after_epochs: number of epochs with stagnant validation losses
to allow before early stopping. If training stops earlier, the to allow before early stopping. If training stops earlier, the
@@ -212,132 +382,39 @@ class Trainer[I, K: EstimatorKwargs]:
while self._epoch <= max_epochs and not self._converged( while self._epoch <= max_epochs and not self._converged(
self._epoch, stop_after_epochs self._epoch, stop_after_epochs
): ):
print(f"Training epoch {self._epoch}/{max_epochs}...") train_frac = f"{self._epoch}/{max_epochs}"
print(f"Stagnant epochs {self._stagnant_epochs}/{stop_after_epochs}...") stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}"
print(f"Training epoch {train_frac}...")
print(f"Stagnant epochs {stag_frac}...")
epoch_start_time = time.time() epoch_start_time = time.time()
train_loss_sums = [] self._train_epoch(
self.estimator.train() train_loader,
with tqdm(train_loader, unit="batch") as train_epoch: batch_estimator_map,
for i, batch_data in enumerate(train_epoch): optimizers,
est_kwargs = batch_estimator_map(batch_data, self) writer,
inputs = est_kwargs["inputs"] max_grad_norm
)
# one-time logging
if self._step == 0:
writer.add_graph(ModelWrapper(self.estimator), est_kwargs)
# once-per-epoch logging
if i == 0:
self.estimator.epoch_write(
writer,
step=self._step,
val=False,
**est_kwargs
)
train_losses = self.estimator.loss(**est_kwargs)
train_loss_items = []
for o_idx, optimizer in enumerate(optimizers):
optimizer.zero_grad()
train_loss = next(train_losses)
if len(train_loss_sums) <= o_idx:
train_loss_sums.append(0.0)
train_loss_item = train_loss.item()
train_loss_sums[o_idx] += train_loss_item
train_loss_items.append(train_loss_item)
train_loss.backward()
# clip gradients for optimizer's parameters
if max_grad_norm is not None:
opt_params = self._get_optimizer_parameters(optimizer)
clip_grad_norm_(opt_params, max_norm=max_grad_norm)
optimizer.step()
self._step += len(inputs)
for train_loss_item, train_loss_sum in zip(
train_loss_items,
train_loss_sums,
strict=True,
):
train_epoch.set_postfix(loss=f"{train_loss_sum/(i+1):8.2f}")
self._add_summary_item("train_loss", train_loss_item)
estimator_metrics = self.estimator.metrics(**est_kwargs)
for metric_name, metric_value in estimator_metrics.items():
self._add_summary_item(f"train_{metric_name}", metric_value)
self.estimator.epoch_step()
for li, train_loss_sum in enumerate(train_loss_sums):
self._add_summary_item(
f"train_loss{li}_epoch", train_loss_sum / len(train_loader)
)
if val_frac > 0: if val_frac > 0:
val_loss_sums = [] self._val_epoch(
self.estimator.eval() val_loader,
with tqdm(val_loader, unit="batch") as val_epoch: batch_estimator_map,
for i, batch_data in enumerate(val_epoch): optimizers,
est_kwargs = batch_estimator_map(batch_data, self) writer,
inputs = est_kwargs["inputs"] )
# once-per-epoch logging self._add_summary_item(
if i == 0: "epoch_time_sec",
self.estimator.epoch_write( time.time() - epoch_start_time
writer, )
step=self._step,
val=True,
**est_kwargs
)
val_losses = self.estimator.loss(**est_kwargs)
val_loss_items = []
for o_idx in range(len(optimizers)):
val_loss = next(val_losses)
if len(val_loss_sums) <= o_idx:
val_loss_sums.append(0.0)
val_loss_item = val_loss.item()
val_loss_sums[o_idx] += val_loss_item
val_loss_items.append(val_loss_item)
for val_loss_item, val_loss_sum in zip(
val_loss_items,
val_loss_sums,
strict=True,
):
val_epoch.set_postfix(loss=f"{val_loss_sum/(i+1):8.2f}")
self._add_summary_item("val_loss", val_loss_item)
estimator_metrics = self.estimator.metrics(**est_kwargs)
for metric_name, metric_value in estimator_metrics.items():
self._add_summary_item(f"val_{metric_name}", metric_value)
for li, val_loss_sum in enumerate(val_loss_sums):
self._add_summary_item(
f"val_loss{li}_epoch", val_loss_sum / len(val_loader)
)
# convergence of multiple losses may be ambiguous
self._val_loss = sum(val_loss_sums) / len(val_loader)
self._add_summary_item("epoch_time_sec", time.time() - epoch_start_time)
if self._epoch % summarize_every == 0: if self._epoch % summarize_every == 0:
self._summarize(writer, self._epoch) self._summarize(writer, self._epoch)
# save checkpoint # save checkpoint
if self._epoch % chkpt_every == 0: if self._epoch % chkpt_every == 0:
self.save_model( self.save_model(self._epoch, self.chkpt_dir, dir_prefix)
self._epoch, self.chkpt_dir, dir_prefix
)
self._epoch += 1 self._epoch += 1
@@ -431,7 +508,7 @@ class Trainer[I, K: EstimatorKwargs]:
def _summarize(self, writer: SummaryWriter, epoch: int) -> None: def _summarize(self, writer: SummaryWriter, epoch: int) -> None:
""" """
Flush the training summary to the TB summary writer. Flush the training summary to the TensorBoard summary writer.
""" """
summary_values = defaultdict(list) summary_values = defaultdict(list)
@@ -485,17 +562,18 @@ class Trainer[I, K: EstimatorKwargs]:
chkpt_dir.mkdir(parents=True, exist_ok=True) chkpt_dir.mkdir(parents=True, exist_ok=True)
chkpt_path.write_bytes(model_buff.getvalue()) chkpt_path.write_bytes(model_buff.getvalue())
def load_model( def load_model(self, epoch: int, chkpt_dir: str) -> None:
self,
epoch: int,
chkpt_dir: str,
) -> None:
""" """
Load a model checkpoint from a given epoch. Load a model checkpoint from a given epoch.
Note that this assumes the model was saved via `Trainer.save_model()`, Note that this assumes the model was saved via
and the estimator provided to this `Trainer` instance matches the ``Trainer.save_model()``, and the estimator provided to this
architecture of the checkpoint model being loaded. ``Trainer`` instance matches the architecture of the checkpoint model
being loaded.
Parameters:
epoch: epoch of saved model
chkpt_dir:
""" """
model_class = self.estimator.__class__.__name__ model_class = self.estimator.__class__.__name__

View File

@@ -1,3 +1,7 @@
"""
Transform base for dataset records
"""
class Transform[I]: class Transform[I]:
""" """
Dataset transform base class. Dataset transform base class.
@@ -8,4 +12,14 @@ class Transform[I]:
""" """
def __call__(self, item: I) -> I: def __call__(self, item: I) -> I:
"""
Apply transform to item.
Parameters:
item: item object to transform
Returns:
transformed item (same type ``I`` as input)
"""
raise NotImplementedError raise NotImplementedError

View File

@@ -0,0 +1,46 @@
text.usetex : False
mathtext.default : regular
font.family : sans-serif
font.sans-serif : DejaVu Sans
font.serif : DejaVu Serif
font.cursive : DejaVu Sans
mathtext.fontset : dejavuserif
font.size : 9
figure.titlesize : 9
legend.fontsize : 9
axes.titlesize : 9
axes.labelsize : 9
xtick.labelsize : 9
ytick.labelsize : 9
#axes.prop_cycle : cycler('color', ['4f7dd5', 'af7031', '55905e', 'd84739', '888348', 'b75e8b', '2f8f99', '9862cb'])
axes.prop_cycle : cycler('color', ['5e8de4', 'c38141', '67a771', 'e15344', '9e9858', '41a6b0', 'a46fd7', 'c86d9a'])
image.interpolation : nearest
image.resample : False
image.composite_image : True
axes.spines.left : True
axes.spines.bottom : True
axes.spines.top : False
axes.spines.right : False
axes.linewidth : 1
xtick.major.width : 1
xtick.minor.width : 1
ytick.major.width : 1
ytick.minor.width : 1
lines.linewidth : 1
lines.markersize : 1
savefig.dpi : 300
savefig.format : svg
savefig.bbox : tight
savefig.pad_inches : 0.1
svg.image_inline : True
svg.fonttype : none
legend.frameon : False

38
trainlib/utils/plot.py Normal file
View File

@@ -0,0 +1,38 @@
from pathlib import Path
import matplotlib as mpl
import matplotlib.pyplot as plt
FILE = Path(__file__).parent.absolute()
class use_style:
    """
    Context manager that temporarily applies matplotlib styles.

    On entry the current :data:`matplotlib.rcParams` are snapshotted and the
    requested styles are applied; on exit the snapshot is restored, so the
    style change is scoped to the ``with`` block.

    Parameters:
        style: list of style specs passed to ``plt.style.use``; defaults to
            the bundled ``custom.mplstyle`` sheet next to this module
        kwargs: individual rcParams overrides applied on top of ``style``
    """

    def __init__(
        self,
        style: list[str] | None = None,
        **kwargs: str,
    ) -> None:
        super().__init__()
        if style is None:
            style = [str(Path(FILE, "custom.mplstyle"))]
        # plt.style.use accepts dicts as style specs, so the keyword
        # overrides are appended as a final (highest-priority) entry.
        self.style = style + [kwargs]
        self.previous_style: dict = {}

    def __enter__(self) -> None:
        # Snapshot the full rc state so __exit__ can restore it exactly.
        # (self.style is always a non-None list after __init__, so no
        # guard is needed before applying it.)
        self.previous_style = mpl.rcParams.copy()
        plt.style.use(self.style)

    def __exit__(self, *args: object) -> None:
        # Restore the pre-context rc state regardless of exceptions.
        mpl.rcParams.update(self.previous_style)
def set_style(
    style: list[str] | None = None,
    **kwargs: str,
) -> None:
    """
    Apply matplotlib styles globally (not restored afterwards).

    Parameters:
        style: list of style specs passed to ``plt.style.use``; defaults to
            the bundled ``custom.mplstyle`` sheet next to this module
        kwargs: individual rcParams overrides applied on top of ``style``
    """
    base = [str(Path(FILE, "custom.mplstyle"))] if style is None else style
    # Keyword overrides go last so they take priority over the sheet(s).
    plt.style.use(base + [kwargs])

21
trainlib/utils/session.py Normal file
View File

@@ -0,0 +1,21 @@
import random
import numpy as np
import torch
from torch import Tensor
def seed_all_backends(seed: int | Tensor | None = None) -> None:
"""Sets all python, numpy and pytorch seeds."""
if seed is None:
seed = int(torch.randint(1000000, size=(1,)))
else:
seed = int(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

100
uv.lock generated
View File

@@ -248,9 +248,13 @@ dependencies = [
{ name = "cuda-pathfinder" }, { name = "cuda-pathfinder" },
] ]
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/05/8b/b4b2d1c7775fa403b64333e720cfcfccef8dcb9cdeb99947061ca5a77628/cuda_bindings-12.9.4-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cf8bfaedc238f3b115d957d1fd6562b7e8435ba57f6d0e2f87d0e7149ccb2da5", size = 11570071, upload-time = "2025-10-21T14:51:47.472Z" },
{ url = "https://files.pythonhosted.org/packages/63/56/e465c31dc9111be3441a9ba7df1941fe98f4aa6e71e8788a3fb4534ce24d/cuda_bindings-12.9.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:32bdc5a76906be4c61eb98f546a6786c5773a881f3b166486449b5d141e4a39f", size = 11906628, upload-time = "2025-10-21T14:51:49.905Z" }, { url = "https://files.pythonhosted.org/packages/63/56/e465c31dc9111be3441a9ba7df1941fe98f4aa6e71e8788a3fb4534ce24d/cuda_bindings-12.9.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:32bdc5a76906be4c61eb98f546a6786c5773a881f3b166486449b5d141e4a39f", size = 11906628, upload-time = "2025-10-21T14:51:49.905Z" },
{ url = "https://files.pythonhosted.org/packages/ec/07/6aff13bc1e977e35aaa6b22f52b172e2890c608c6db22438cf7ed2bf43a6/cuda_bindings-12.9.4-cp313-cp313t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3adf4958dcf68ae7801a59b73fb00a8b37f8d0595060d66ceae111b1002de38d", size = 11566797, upload-time = "2025-10-21T14:51:54.581Z" },
{ url = "https://files.pythonhosted.org/packages/a3/84/1e6be415e37478070aeeee5884c2022713c1ecc735e6d82d744de0252eee/cuda_bindings-12.9.4-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:56e0043c457a99ac473ddc926fe0dc4046694d99caef633e92601ab52cbe17eb", size = 11925991, upload-time = "2025-10-21T14:51:56.535Z" }, { url = "https://files.pythonhosted.org/packages/a3/84/1e6be415e37478070aeeee5884c2022713c1ecc735e6d82d744de0252eee/cuda_bindings-12.9.4-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:56e0043c457a99ac473ddc926fe0dc4046694d99caef633e92601ab52cbe17eb", size = 11925991, upload-time = "2025-10-21T14:51:56.535Z" },
{ url = "https://files.pythonhosted.org/packages/1e/b5/96a6696e20c4ffd2b327f54c7d0fde2259bdb998d045c25d5dedbbe30290/cuda_bindings-12.9.4-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1f53a7f453d4b2643d8663d036bafe29b5ba89eb904c133180f295df6dc151e5", size = 11624530, upload-time = "2025-10-21T14:52:01.539Z" },
{ url = "https://files.pythonhosted.org/packages/d1/af/6dfd8f2ed90b1d4719bc053ff8940e494640fe4212dc3dd72f383e4992da/cuda_bindings-12.9.4-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8b72ee72a9cc1b531db31eebaaee5c69a8ec3500e32c6933f2d3b15297b53686", size = 11922703, upload-time = "2025-10-21T14:52:03.585Z" }, { url = "https://files.pythonhosted.org/packages/d1/af/6dfd8f2ed90b1d4719bc053ff8940e494640fe4212dc3dd72f383e4992da/cuda_bindings-12.9.4-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8b72ee72a9cc1b531db31eebaaee5c69a8ec3500e32c6933f2d3b15297b53686", size = 11922703, upload-time = "2025-10-21T14:52:03.585Z" },
{ url = "https://files.pythonhosted.org/packages/39/73/d2fc40c043bac699c3880bf88d3cebe9d88410cd043795382826c93a89f0/cuda_bindings-12.9.4-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:20f2699d61d724de3eb3f3369d57e2b245f93085cab44fd37c3bea036cea1a6f", size = 11565056, upload-time = "2025-10-21T14:52:08.338Z" },
{ url = "https://files.pythonhosted.org/packages/6c/19/90ac264acc00f6df8a49378eedec9fd2db3061bf9263bf9f39fd3d8377c3/cuda_bindings-12.9.4-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d80bffc357df9988dca279734bc9674c3934a654cab10cadeed27ce17d8635ee", size = 11924658, upload-time = "2025-10-21T14:52:10.411Z" }, { url = "https://files.pythonhosted.org/packages/6c/19/90ac264acc00f6df8a49378eedec9fd2db3061bf9263bf9f39fd3d8377c3/cuda_bindings-12.9.4-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d80bffc357df9988dca279734bc9674c3934a654cab10cadeed27ce17d8635ee", size = 11924658, upload-time = "2025-10-21T14:52:10.411Z" },
] ]
@@ -861,6 +865,7 @@ name = "nvidia-cublas-cu12"
version = "12.8.4.1" version = "12.8.4.1"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/29/99/db44d685f0e257ff0e213ade1964fc459b4a690a73293220e98feb3307cf/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:b86f6dd8935884615a0683b663891d43781b819ac4f2ba2b0c9604676af346d0", size = 590537124, upload-time = "2025-03-07T01:43:53.556Z" },
{ url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921, upload-time = "2025-03-07T01:44:31.254Z" }, { url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921, upload-time = "2025-03-07T01:44:31.254Z" },
] ]
@@ -869,6 +874,7 @@ name = "nvidia-cuda-cupti-cu12"
version = "12.8.90" version = "12.8.90"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/d5/1f/b3bd73445e5cb342727fd24fe1f7b748f690b460acadc27ea22f904502c8/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:4412396548808ddfed3f17a467b104ba7751e6b58678a4b840675c56d21cf7ed", size = 9533318, upload-time = "2025-03-07T01:40:10.421Z" },
{ url = "https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621, upload-time = "2025-03-07T01:40:21.213Z" }, { url = "https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621, upload-time = "2025-03-07T01:40:21.213Z" },
] ]
@@ -878,6 +884,7 @@ version = "12.8.93"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029, upload-time = "2025-03-07T01:42:13.562Z" }, { url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029, upload-time = "2025-03-07T01:42:13.562Z" },
{ url = "https://files.pythonhosted.org/packages/eb/d1/e50d0acaab360482034b84b6e27ee83c6738f7d32182b987f9c7a4e32962/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fc1fec1e1637854b4c0a65fb9a8346b51dd9ee69e61ebaccc82058441f15bce8", size = 43106076, upload-time = "2025-03-07T01:41:59.817Z" },
] ]
[[package]] [[package]]
@@ -885,6 +892,7 @@ name = "nvidia-cuda-runtime-cu12"
version = "12.8.90" version = "12.8.90"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/7c/75/f865a3b236e4647605ea34cc450900854ba123834a5f1598e160b9530c3a/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:52bf7bbee900262ffefe5e9d5a2a69a30d97e2bc5bb6cc866688caa976966e3d", size = 965265, upload-time = "2025-03-07T01:39:43.533Z" },
{ url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765, upload-time = "2025-03-07T01:40:01.615Z" }, { url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765, upload-time = "2025-03-07T01:40:01.615Z" },
] ]
@@ -896,6 +904,7 @@ dependencies = [
{ name = "nvidia-cublas-cu12" }, { name = "nvidia-cublas-cu12" },
] ]
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/fa/41/e79269ce215c857c935fd86bcfe91a451a584dfc27f1e068f568b9ad1ab7/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8", size = 705026878, upload-time = "2025-06-06T21:52:51.348Z" },
{ url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" },
] ]
@@ -907,6 +916,7 @@ dependencies = [
{ name = "nvidia-nvjitlink-cu12" }, { name = "nvidia-nvjitlink-cu12" },
] ]
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = "2025-03-07T01:44:56.873Z" },
{ url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" },
] ]
@@ -916,6 +926,7 @@ version = "1.13.1.3"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834, upload-time = "2025-03-07T01:45:50.723Z" }, { url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834, upload-time = "2025-03-07T01:45:50.723Z" },
{ url = "https://files.pythonhosted.org/packages/1e/f5/5607710447a6fe9fd9b3283956fceeee8a06cda1d2f56ce31371f595db2a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:4beb6d4cce47c1a0f1013d72e02b0994730359e17801d395bdcbf20cfb3bb00a", size = 1120705, upload-time = "2025-03-07T01:45:41.434Z" },
] ]
[[package]] [[package]]
@@ -923,6 +934,7 @@ name = "nvidia-curand-cu12"
version = "10.3.9.90" version = "10.3.9.90"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/45/5e/92aa15eca622a388b80fbf8375d4760738df6285b1e92c43d37390a33a9a/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:dfab99248034673b779bc6decafdc3404a8a6f502462201f2f31f11354204acd", size = 63625754, upload-time = "2025-03-07T01:46:10.735Z" },
{ url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976, upload-time = "2025-03-07T01:46:23.323Z" }, { url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976, upload-time = "2025-03-07T01:46:23.323Z" },
] ]
@@ -936,6 +948,7 @@ dependencies = [
{ name = "nvidia-nvjitlink-cu12" }, { name = "nvidia-nvjitlink-cu12" },
] ]
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = "2025-03-07T01:46:54.356Z" },
{ url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" },
] ]
@@ -947,6 +960,7 @@ dependencies = [
{ name = "nvidia-nvjitlink-cu12" }, { name = "nvidia-nvjitlink-cu12" },
] ]
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129, upload-time = "2025-03-07T01:47:40.407Z" },
{ url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" },
] ]
@@ -955,6 +969,7 @@ name = "nvidia-cusparselt-cu12"
version = "0.7.1" version = "0.7.1"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/73/b9/598f6ff36faaece4b3c50d26f50e38661499ff34346f00e057760b35cc9d/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8878dce784d0fac90131b6817b607e803c36e629ba34dc5b433471382196b6a5", size = 283835557, upload-time = "2025-02-26T00:16:54.265Z" },
{ url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691, upload-time = "2025-02-26T00:15:44.104Z" }, { url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691, upload-time = "2025-02-26T00:15:44.104Z" },
] ]
@@ -963,6 +978,7 @@ name = "nvidia-nccl-cu12"
version = "2.27.5" version = "2.27.5"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/bb/1c/857979db0ef194ca5e21478a0612bcdbbe59458d7694361882279947b349/nvidia_nccl_cu12-2.27.5-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:31432ad4d1fb1004eb0c56203dc9bc2178a1ba69d1d9e02d64a6938ab5e40e7a", size = 322400625, upload-time = "2025-06-26T04:11:04.496Z" },
{ url = "https://files.pythonhosted.org/packages/6e/89/f7a07dc961b60645dbbf42e80f2bc85ade7feb9a491b11a1e973aa00071f/nvidia_nccl_cu12-2.27.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ad730cf15cb5d25fe849c6e6ca9eb5b76db16a80f13f425ac68d8e2e55624457", size = 322348229, upload-time = "2025-06-26T04:11:28.385Z" }, { url = "https://files.pythonhosted.org/packages/6e/89/f7a07dc961b60645dbbf42e80f2bc85ade7feb9a491b11a1e973aa00071f/nvidia_nccl_cu12-2.27.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ad730cf15cb5d25fe849c6e6ca9eb5b76db16a80f13f425ac68d8e2e55624457", size = 322348229, upload-time = "2025-06-26T04:11:28.385Z" },
] ]
@@ -972,6 +988,7 @@ version = "12.8.93"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836, upload-time = "2025-03-07T01:49:55.661Z" }, { url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836, upload-time = "2025-03-07T01:49:55.661Z" },
{ url = "https://files.pythonhosted.org/packages/2a/a2/8cee5da30d13430e87bf99bb33455d2724d0a4a9cb5d7926d80ccb96d008/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:adccd7161ace7261e01bb91e44e88da350895c270d23f744f0820c818b7229e7", size = 38386204, upload-time = "2025-03-07T01:49:43.612Z" },
] ]
[[package]] [[package]]
@@ -979,6 +996,7 @@ name = "nvidia-nvshmem-cu12"
version = "3.4.5" version = "3.4.5"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/1d/6a/03aa43cc9bd3ad91553a88b5f6fb25ed6a3752ae86ce2180221962bc2aa5/nvidia_nvshmem_cu12-3.4.5-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0b48363fc6964dede448029434c6abed6c5e37f823cb43c3bcde7ecfc0457e15", size = 138936938, upload-time = "2025-09-06T00:32:05.589Z" },
{ url = "https://files.pythonhosted.org/packages/b5/09/6ea3ea725f82e1e76684f0708bbedd871fc96da89945adeba65c3835a64c/nvidia_nvshmem_cu12-3.4.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:042f2500f24c021db8a06c5eec2539027d57460e1c1a762055a6554f72c369bd", size = 139103095, upload-time = "2025-09-06T00:32:31.266Z" }, { url = "https://files.pythonhosted.org/packages/b5/09/6ea3ea725f82e1e76684f0708bbedd871fc96da89945adeba65c3835a64c/nvidia_nvshmem_cu12-3.4.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:042f2500f24c021db8a06c5eec2539027d57460e1c1a762055a6554f72c369bd", size = 139103095, upload-time = "2025-09-06T00:32:31.266Z" },
] ]
@@ -987,6 +1005,7 @@ name = "nvidia-nvtx-cu12"
version = "12.8.90" version = "12.8.90"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/10/c0/1b303feea90d296f6176f32a2a70b5ef230f9bdeb3a72bddb0dc922dc137/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d7ad891da111ebafbf7e015d34879f7112832fc239ff0d7d776b6cb685274615", size = 91161, upload-time = "2025-03-07T01:42:23.922Z" },
{ url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" }, { url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" },
] ]
@@ -1393,14 +1412,14 @@ wheels = [
[[package]] [[package]]
name = "sphinx-autodoc-typehints" name = "sphinx-autodoc-typehints"
version = "3.9.5" version = "3.9.7"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "sphinx" }, { name = "sphinx" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/58/ec/21bd9babcfeb9930a73011257002d5cfa5fd30667b8de6d76dbaf8275dfb/sphinx_autodoc_typehints-3.9.5.tar.gz", hash = "sha256:60e646efb7c352a0e98f34dd7fdcde4527fbbdbdf30371ff8321b6b3eb1fd37d", size = 63249, upload-time = "2026-03-02T19:58:07.974Z" } sdist = { url = "https://files.pythonhosted.org/packages/f4/06/da2d9e98b3f7f0df144496e62f453e0025f129bccc7a6076b8ceae6047b1/sphinx_autodoc_typehints-3.9.7.tar.gz", hash = "sha256:70f3dd4e4dd815ae30e5d3848a26dca71fb5e7fcf8f37cf8b840dc8afdf07e82", size = 68689, upload-time = "2026-03-05T18:33:40.829Z" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/7f/cb/80c250f47a0ca5ac67d82f14811b4068a551a12b4790b085ffdb900de427/sphinx_autodoc_typehints-3.9.5-py3-none-any.whl", hash = "sha256:c94f88a90b6c61a7a6686cb77b410e46e077712838387e6cf22d69e85cfd06a5", size = 34763, upload-time = "2026-03-02T19:58:06.028Z" }, { url = "https://files.pythonhosted.org/packages/a4/a0/e7d3365dabfa79a1b2ac7d3122b5b22b401a9c4d5e4eadc5e13b88c63a2c/sphinx_autodoc_typehints-3.9.7-py3-none-any.whl", hash = "sha256:dd73f6a32adef0d8208f6f7d99254e1880259c77db7b4a91648345d45202d48e", size = 36691, upload-time = "2026-03-05T18:33:38.983Z" },
] ]
[[package]] [[package]]
@@ -1542,52 +1561,47 @@ wheels = [
[[package]] [[package]]
name = "torch" name = "torch"
version = "2.10.0" version = "2.10.0+cu128"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://download.pytorch.org/whl/cu128" }
dependencies = [ dependencies = [
{ name = "cuda-bindings", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "cuda-bindings", marker = "sys_platform == 'linux'" },
{ name = "filelock" }, { name = "filelock" },
{ name = "fsspec" }, { name = "fsspec" },
{ name = "jinja2" }, { name = "jinja2" },
{ name = "networkx" }, { name = "networkx" },
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cuda-cupti-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cuda-nvrtc-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cuda-runtime-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cudnn-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cufft-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cufile-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-curand-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cusolver-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cusparselt-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nccl-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-nvshmem-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nvshmem-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nvtx-cu12", marker = "sys_platform == 'linux'" },
{ name = "setuptools" }, { name = "setuptools" },
{ name = "sympy" }, { name = "sympy" },
{ name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "triton", marker = "sys_platform == 'linux'" },
{ name = "typing-extensions" }, { name = "typing-extensions" },
] ]
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/ec/23/2c9fe0c9c27f7f6cb865abcea8a4568f29f00acaeadfc6a37f6801f84cb4/torch-2.10.0-2-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:e521c9f030a3774ed770a9c011751fb47c4d12029a3d6522116e48431f2ff89e", size = 79498254, upload-time = "2026-02-10T21:44:44.095Z" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:bdbcc703382f948e951c063448c9406bf38ce66c41dd698d9e2733fcf96c037a" },
{ url = "https://files.pythonhosted.org/packages/c9/6f/f2e91e34e3fcba2e3fc8d8f74e7d6c22e74e480bbd1db7bc8900fdf3e95c/torch-2.10.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:5c4d217b14741e40776dd7074d9006fd28b8a97ef5654db959d8635b2fe5f29b", size = 146004247, upload-time = "2026-01-21T16:24:29.335Z" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:7b4bd23ed63de97456fcc81c26fea9f02ee02ce1112111c4dac0d8cfe574b23e" },
{ url = "https://files.pythonhosted.org/packages/98/fb/5160261aeb5e1ee12ee95fe599d0541f7c976c3701d607d8fc29e623229f/torch-2.10.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:6b71486353fce0f9714ca0c9ef1c850a2ae766b409808acd58e9678a3edb7738", size = 915716445, upload-time = "2026-01-21T16:22:45.353Z" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313-win_amd64.whl", hash = "sha256:4d1b0b49c54223c7c04050b49eac141d77b6edbc34aea1dfc74a6fdb661baa8c" },
{ url = "https://files.pythonhosted.org/packages/6a/16/502fb1b41e6d868e8deb5b0e3ae926bbb36dab8ceb0d1b769b266ad7b0c3/torch-2.10.0-cp313-cp313-win_amd64.whl", hash = "sha256:c2ee399c644dc92ef7bc0d4f7e74b5360c37cdbe7c5ba11318dda49ffac2bc57", size = 113757050, upload-time = "2026-01-21T16:24:19.204Z" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:f1f8b840c64b645a4bc61a393db48effb9c92b2dc26c8373873911f0750d1ea7" },
{ url = "https://files.pythonhosted.org/packages/1a/0b/39929b148f4824bc3ad6f9f72a29d4ad865bcf7ebfc2fa67584773e083d2/torch-2.10.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:3202429f58309b9fa96a614885eace4b7995729f44beb54d3e4a47773649d382", size = 79851305, upload-time = "2026-01-21T16:24:09.209Z" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:23f58258012bcf1c349cb22af387e33aadca7f83ea617b080e774eb41e4fe8ff" },
{ url = "https://files.pythonhosted.org/packages/d8/14/21fbce63bc452381ba5f74a2c0a959fdf5ad5803ccc0c654e752e0dbe91a/torch-2.10.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:aae1b29cd68e50a9397f5ee897b9c24742e9e306f88a807a27d617f07adb3bd8", size = 146005472, upload-time = "2026-01-21T16:22:29.022Z" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313t-win_amd64.whl", hash = "sha256:01b216e097b17a5277cfb47c383cdcacf06abeadcb0daca0c76b59e72854c3b6" },
{ url = "https://files.pythonhosted.org/packages/54/fd/b207d1c525cb570ef47f3e9f836b154685011fce11a2f444ba8a4084d042/torch-2.10.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:6021db85958db2f07ec94e1bc77212721ba4920c12a18dc552d2ae36a3eb163f", size = 915612644, upload-time = "2026-01-21T16:21:47.019Z" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:c42377bc2607e3e1c60da71b792fb507c3938c87fd6edab8b21c59c91473c36d" },
{ url = "https://files.pythonhosted.org/packages/36/53/0197f868c75f1050b199fe58f9bf3bf3aecac9b4e85cc9c964383d745403/torch-2.10.0-cp313-cp313t-win_amd64.whl", hash = "sha256:ff43db38af76fda183156153983c9a096fc4c78d0cd1e07b14a2314c7f01c2c8", size = 113997015, upload-time = "2026-01-21T16:23:00.767Z" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:37d71feea068776855686a1512058df3f19f6f040a151f055aa746601678744f" },
{ url = "https://files.pythonhosted.org/packages/0e/13/e76b4d9c160e89fff48bf16b449ea324bda84745d2ab30294c37c2434c0d/torch-2.10.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:cdf2a523d699b70d613243211ecaac14fe9c5df8a0b0a9c02add60fb2a413e0f", size = 79498248, upload-time = "2026-01-21T16:23:09.315Z" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314-win_amd64.whl", hash = "sha256:c57017ca29e62271e362fdeee7d20070e254755a5148b30b553d8a10fc83c7ef" },
{ url = "https://files.pythonhosted.org/packages/4f/93/716b5ac0155f1be70ed81bacc21269c3ece8dba0c249b9994094110bfc51/torch-2.10.0-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:bf0d9ff448b0218e0433aeb198805192346c4fd659c852370d5cc245f602a06a", size = 79464992, upload-time = "2026-01-21T16:23:05.162Z" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:777461f50b2daf77e4bdd8e2ad34bdfc5a993bf1bdf2ab9ef39f5edfe4e9c12b" },
{ url = "https://files.pythonhosted.org/packages/69/2b/51e663ff190c9d16d4a8271203b71bc73a16aa7619b9f271a69b9d4a936b/torch-2.10.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:233aed0659a2503b831d8a67e9da66a62c996204c0bba4f4c442ccc0c68a3f60", size = 146018567, upload-time = "2026-01-21T16:22:23.393Z" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:7bcba6a7c5f0987a13298b1ca843155dcceceac758fa3c7ccd5c7af4059a1080" },
{ url = "https://files.pythonhosted.org/packages/5e/cd/4b95ef7f293b927c283db0b136c42be91c8ec6845c44de0238c8c23bdc80/torch-2.10.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:682497e16bdfa6efeec8cde66531bc8d1fbbbb4d8788ec6173c089ed3cc2bfe5", size = 915721646, upload-time = "2026-01-21T16:21:16.983Z" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314t-win_amd64.whl", hash = "sha256:70d89143c956389d4806cb4e5fe0b1129fe0db280e1073288d17fa76c101cba4" },
{ url = "https://files.pythonhosted.org/packages/56/97/078a007208f8056d88ae43198833469e61a0a355abc0b070edd2c085eb9a/torch-2.10.0-cp314-cp314-win_amd64.whl", hash = "sha256:6528f13d2a8593a1a412ea07a99812495bec07e9224c28b2a25c0a30c7da025c", size = 113752373, upload-time = "2026-01-21T16:22:13.471Z" },
{ url = "https://files.pythonhosted.org/packages/d8/94/71994e7d0d5238393df9732fdab607e37e2b56d26a746cb59fdb415f8966/torch-2.10.0-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:f5ab4ba32383061be0fb74bda772d470140a12c1c3b58a0cfbf3dae94d164c28", size = 79850324, upload-time = "2026-01-21T16:22:09.494Z" },
{ url = "https://files.pythonhosted.org/packages/e2/65/1a05346b418ea8ccd10360eef4b3e0ce688fba544e76edec26913a8d0ee0/torch-2.10.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:716b01a176c2a5659c98f6b01bf868244abdd896526f1c692712ab36dbaf9b63", size = 146006482, upload-time = "2026-01-21T16:22:18.42Z" },
{ url = "https://files.pythonhosted.org/packages/1d/b9/5f6f9d9e859fc3235f60578fa64f52c9c6e9b4327f0fe0defb6de5c0de31/torch-2.10.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:d8f5912ba938233f86361e891789595ff35ca4b4e2ac8fe3670895e5976731d6", size = 915613050, upload-time = "2026-01-21T16:20:49.035Z" },
{ url = "https://files.pythonhosted.org/packages/66/4d/35352043ee0eaffdeff154fad67cd4a31dbed7ff8e3be1cc4549717d6d51/torch-2.10.0-cp314-cp314t-win_amd64.whl", hash = "sha256:71283a373f0ee2c89e0f0d5f446039bdabe8dbc3c9ccf35f0f784908b0acd185", size = 113995816, upload-time = "2026-01-21T16:22:05.312Z" },
] ]
[[package]] [[package]]
@@ -1623,7 +1637,7 @@ wheels = [
[[package]] [[package]]
name = "trainlib" name = "trainlib"
version = "0.1.0" version = "0.1.2"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "colorama" }, { name = "colorama" },
@@ -1662,7 +1676,7 @@ requires-dist = [
{ name = "sphinx-autodoc-typehints", marker = "extra == 'doc'" }, { name = "sphinx-autodoc-typehints", marker = "extra == 'doc'" },
{ name = "sphinx-togglebutton", marker = "extra == 'doc'" }, { name = "sphinx-togglebutton", marker = "extra == 'doc'" },
{ name = "tensorboard", specifier = ">=2.20.0" }, { name = "tensorboard", specifier = ">=2.20.0" },
{ name = "torch", specifier = ">=2.5.1" }, { name = "torch", index = "https://download.pytorch.org/whl/cu128" },
{ name = "tqdm", specifier = ">=4.67.1" }, { name = "tqdm", specifier = ">=4.67.1" },
] ]
provides-extras = ["dev", "doc", "test"] provides-extras = ["dev", "doc", "test"]
@@ -1681,9 +1695,13 @@ name = "triton"
version = "3.6.0" version = "3.6.0"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/3c/12/34d71b350e89a204c2c7777a9bba0dcf2f19a5bfdd70b57c4dbc5ffd7154/triton-3.6.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:448e02fe6dc898e9e5aa89cf0ee5c371e99df5aa5e8ad976a80b93334f3494fd", size = 176133521, upload-time = "2026-01-20T16:16:13.321Z" },
{ url = "https://files.pythonhosted.org/packages/f9/0b/37d991d8c130ce81a8728ae3c25b6e60935838e9be1b58791f5997b24a54/triton-3.6.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10c7f76c6e72d2ef08df639e3d0d30729112f47a56b0c81672edc05ee5116ac9", size = 188289450, upload-time = "2026-01-20T16:00:49.136Z" }, { url = "https://files.pythonhosted.org/packages/f9/0b/37d991d8c130ce81a8728ae3c25b6e60935838e9be1b58791f5997b24a54/triton-3.6.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10c7f76c6e72d2ef08df639e3d0d30729112f47a56b0c81672edc05ee5116ac9", size = 188289450, upload-time = "2026-01-20T16:00:49.136Z" },
{ url = "https://files.pythonhosted.org/packages/ce/4e/41b0c8033b503fd3cfcd12392cdd256945026a91ff02452bef40ec34bee7/triton-3.6.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1722e172d34e32abc3eb7711d0025bb69d7959ebea84e3b7f7a341cd7ed694d6", size = 176276087, upload-time = "2026-01-20T16:16:18.989Z" },
{ url = "https://files.pythonhosted.org/packages/35/f8/9c66bfc55361ec6d0e4040a0337fb5924ceb23de4648b8a81ae9d33b2b38/triton-3.6.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d002e07d7180fd65e622134fbd980c9a3d4211fb85224b56a0a0efbd422ab72f", size = 188400296, upload-time = "2026-01-20T16:00:56.042Z" }, { url = "https://files.pythonhosted.org/packages/35/f8/9c66bfc55361ec6d0e4040a0337fb5924ceb23de4648b8a81ae9d33b2b38/triton-3.6.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d002e07d7180fd65e622134fbd980c9a3d4211fb85224b56a0a0efbd422ab72f", size = 188400296, upload-time = "2026-01-20T16:00:56.042Z" },
{ url = "https://files.pythonhosted.org/packages/49/55/5ecf0dcaa0f2fbbd4420f7ef227ee3cb172e91e5fede9d0ecaddc43363b4/triton-3.6.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ef5523241e7d1abca00f1d240949eebdd7c673b005edbbce0aca95b8191f1d43", size = 176138577, upload-time = "2026-01-20T16:16:25.426Z" },
{ url = "https://files.pythonhosted.org/packages/df/3d/9e7eee57b37c80cec63322c0231bb6da3cfe535a91d7a4d64896fcb89357/triton-3.6.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a17a5d5985f0ac494ed8a8e54568f092f7057ef60e1b0fa09d3fd1512064e803", size = 188273063, upload-time = "2026-01-20T16:01:07.278Z" }, { url = "https://files.pythonhosted.org/packages/df/3d/9e7eee57b37c80cec63322c0231bb6da3cfe535a91d7a4d64896fcb89357/triton-3.6.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a17a5d5985f0ac494ed8a8e54568f092f7057ef60e1b0fa09d3fd1512064e803", size = 188273063, upload-time = "2026-01-20T16:01:07.278Z" },
{ url = "https://files.pythonhosted.org/packages/48/db/56ee649cab5eaff4757541325aca81f52d02d4a7cd3506776cad2451e060/triton-3.6.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0b3a97e8ed304dfa9bd23bb41ca04cdf6b2e617d5e782a8653d616037a5d537d", size = 176274804, upload-time = "2026-01-20T16:16:31.528Z" },
{ url = "https://files.pythonhosted.org/packages/f6/56/6113c23ff46c00aae423333eb58b3e60bdfe9179d542781955a5e1514cb3/triton-3.6.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:46bd1c1af4b6704e554cad2eeb3b0a6513a980d470ccfa63189737340c7746a7", size = 188397994, upload-time = "2026-01-20T16:01:14.236Z" }, { url = "https://files.pythonhosted.org/packages/f6/56/6113c23ff46c00aae423333eb58b3e60bdfe9179d542781955a5e1514cb3/triton-3.6.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:46bd1c1af4b6704e554cad2eeb3b0a6513a980d470ccfa63189737340c7746a7", size = 188397994, upload-time = "2026-01-20T16:01:14.236Z" },
] ]