Compare commits
4 Commits
c473e48b5b
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| c2e4294c8c | |||
| e867bc0e7f | |||
| faeef9c72a | |||
| 805262dfc4 |
21
README.md
21
README.md
@@ -1,17 +1,12 @@
|
||||
# Overview
|
||||
Package summary goes here, ideally with a diagram
|
||||
Minimal framework for ML modeling, supporting advanced dataset operations and
|
||||
streamlined training workflows.
|
||||
|
||||
# Install
|
||||
Installation instructions
|
||||
The `trainlib` package can be installed from PyPI:
|
||||
|
||||
```sh
|
||||
pip install <package>
|
||||
```
|
||||
|
||||
or as a CLI tool
|
||||
|
||||
```sh
|
||||
uv tool install <package>
|
||||
pip install trainlib
|
||||
```
|
||||
|
||||
# Development
|
||||
@@ -20,16 +15,16 @@ uv tool install <package>
|
||||
- Depending on needs, install the development dependencies with `uv sync
|
||||
--extra dev`.
|
||||
|
||||
# Testing
|
||||
## Testing
|
||||
- To run the unit tests, make sure to first have the test dependencies
|
||||
installed with `uv sync --extra test`, then run `make test`.
|
||||
- For notebook testing, run `make install-kernel` to make the environment
|
||||
available as a Jupyter kernel (to be selected when running notebooks).
|
||||
|
||||
# Documentation
|
||||
## Documentation
|
||||
- Install the documentation dependencies with `uv sync --extra doc`.
|
||||
- Run `make docs-build` (optionally preceded by `make docs-clean`), and serve
|
||||
locally with `docs-serve`.
|
||||
locally with `make docs-serve`.
|
||||
|
||||
# Development remarks
|
||||
- Across `Trainer` / `Estimator` / `Dataset`, I've considered a
|
||||
@@ -91,7 +86,7 @@ uv tool install <package>
|
||||
class SequenceDataset[I, **P](HomogenousDataset[int, I, I, P]):
|
||||
...
|
||||
|
||||
class TupleDataset[I](SequenceDataset[tuple[I, ...], ??]):
|
||||
class TupleDataset[I](SequenceDataset[tuple[I, ...], "?"]):
|
||||
...
|
||||
```
|
||||
|
||||
|
||||
20
doc/Makefile
Normal file
20
doc/Makefile
Normal file
@@ -0,0 +1,20 @@
|
||||
# Minimal makefile for Sphinx documentation
|
||||
#
|
||||
|
||||
# You can set these variables from the command line, and also
|
||||
# from the environment for the first two.
|
||||
SPHINXOPTS ?=
|
||||
SPHINXBUILD ?= sphinx-build
|
||||
SOURCEDIR = .
|
||||
BUILDDIR = _build
|
||||
|
||||
# Put it first so that "make" without argument is like "make help".
|
||||
help:
|
||||
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
|
||||
.PHONY: help Makefile
|
||||
|
||||
# Catch-all target: route all unknown targets to Sphinx using the new
|
||||
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
||||
%: Makefile
|
||||
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
9
doc/_templates/autosummary.md
vendored
Normal file
9
doc/_templates/autosummary.md
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
# {{ fullname | escape }}
|
||||
|
||||
```{automodule}
|
||||
{{ fullname }}
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
:imported-members:
|
||||
```
|
||||
8
doc/_templates/autosummary/module.rst
vendored
Normal file
8
doc/_templates/autosummary/module.rst
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
{{ fullname | escape | underline}}
|
||||
|
||||
.. automodule:: {{ fullname }}
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
:imported-members:
|
||||
|
||||
112
doc/conf.py
Normal file
112
doc/conf.py
Normal file
@@ -0,0 +1,112 @@
|
||||
# Configuration file for the Sphinx documentation builder.
|
||||
#
|
||||
# For the full list of built-in configuration values, see the documentation:
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
||||
|
||||
|
||||
# -- Styling: type hints ------------------------------------------------------
|
||||
# There are several possible style combinations for rendering types, none of
|
||||
# which are optimal in my view. The main switches are:
|
||||
#
|
||||
# - Parameter type hints in the signature vs in the separate parameter list
|
||||
# - Show type hints as plaintext vs rendered HTML elements
|
||||
#
|
||||
# The `sphinx_autodoc_typehints` extension enables more context-aware
|
||||
# rendering, but it's often way too explicit (e.g., unwrapping type variables)
|
||||
# and makes things difficult to read. It does, however, allow for automatic
|
||||
# inclusion of default values, which is nice.
|
||||
#
|
||||
# I'd like type hints to be rendered in an inline code element, but that
|
||||
# doesn't happen by default in either case unless you render them in the
|
||||
# signature. This is sloppy, however, often just a jumbled mess or parameter
|
||||
# names and types. The current preferred option is to just use the native
|
||||
# `autodoc` settings for rendering type hints, leaving them out of the
|
||||
# signature (for easy heading readability). Type hints in the parameter list
|
||||
# are also as short as possible, not rendered crazily (by default this is in
|
||||
# italics; not my favorite but it's what we have). No
|
||||
# `sphinx_autodoc_typehints` needed at this point; you can toggle it if you
|
||||
# want automatic default values or different formatting for type hints.
|
||||
|
||||
|
||||
# -- Project information ------------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
|
||||
|
||||
project = "trainlib"
|
||||
copyright = "2026, Sam Griesemer"
|
||||
author = "Sam Griesemer"
|
||||
|
||||
|
||||
# -- General configuration ----------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
||||
|
||||
extensions = [
|
||||
"sphinx.ext.autodoc",
|
||||
|
||||
# enables a directive to be specified manually that gathers module/object
|
||||
# summary details in a table
|
||||
"sphinx.ext.autosummary",
|
||||
|
||||
# allow viewing source in the HTML pages
|
||||
"sphinx.ext.viewcode",
|
||||
|
||||
# only really applies to manual docs; docstrings still need RST-like
|
||||
"myst_parser",
|
||||
|
||||
# enables Google-style docstring formats
|
||||
"sphinx.ext.napoleon",
|
||||
|
||||
# external extension that allows arg types to be inferred by type hints;
|
||||
# without this, type hints show up inside method signatures as plaintext,
|
||||
# but when enabled they are pulled into the parameter/description block and
|
||||
# rendered as native nested markup. What's best for a given package may
|
||||
# vary.
|
||||
# "sphinx_autodoc_typehints",
|
||||
]
|
||||
autosummary_generate = True
|
||||
autosummary_imported_members = True
|
||||
|
||||
# include __init__ definitions in autodoc
|
||||
autodoc_default_options = {
|
||||
"special-members": "__init__",
|
||||
}
|
||||
|
||||
templates_path = ["_templates"]
|
||||
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
||||
|
||||
|
||||
# -- Options for autodoc ------------------------------------------------------
|
||||
# class signatures show up only in __init__ rather than at the class header;
|
||||
# generally cleaner, avoids redundancy
|
||||
autodoc_class_signature = "separated"
|
||||
|
||||
# if `sphinx_autodoc_typehints` extension is enabled, this is redundant: type
|
||||
# hints are rendered natively and already show up in the parameter block. If
|
||||
# it's disabled, this setting will do the same job of moving the types to the
|
||||
# parameter block, but it renders them in plaintext (with links to in-package
|
||||
# type refs).
|
||||
autodoc_typehints = "description" # "signature"
|
||||
autodoc_typehints_format = "short"
|
||||
autodoc_preserve_defaults = True
|
||||
autodoc_use_type_comments = False
|
||||
python_use_unqualified_type_names = True
|
||||
|
||||
# push parameters to their own lines in the signature block
|
||||
# python_maximum_signature_line_length = 60
|
||||
|
||||
|
||||
# -- Options for autodoc_typehints --------------------------------------------
|
||||
# always_use_bars_union = True # always on for Python 3.14+
|
||||
# typehints_defaults = "braces-after" # render defaults in param block
|
||||
# typehints_use_signature = False # False is default; enable if wanted in sig
|
||||
# always_document_param_types = True # show types even when not in docstring
|
||||
|
||||
|
||||
# -- Options for HTML output --------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
|
||||
|
||||
html_theme = "furo" # "pydata_sphinx_theme"
|
||||
html_static_path = ["_static"]
|
||||
|
||||
# html_sidebars = {
|
||||
# '**': ['/modules.html'],
|
||||
# }
|
||||
37
doc/index.md
Normal file
37
doc/index.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# `trainlib` package docs
|
||||
|
||||
{ref}`genindex`
|
||||
{ref}`modindex`
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:nosignatures:
|
||||
:recursive:
|
||||
:caption: Modules
|
||||
|
||||
trainlib.dataset
|
||||
trainlib.domain
|
||||
trainlib.estimator
|
||||
trainlib.trainer
|
||||
trainlib.transform
|
||||
```
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 3
|
||||
:caption: Autoref
|
||||
:hidden:
|
||||
|
||||
_autoref/trainlib.rst
|
||||
```
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 3
|
||||
:caption: Contents
|
||||
:hidden:
|
||||
|
||||
reference/documentation/index
|
||||
```
|
||||
|
||||
```{include} ../README.md
|
||||
:heading-offset: 1
|
||||
```
|
||||
35
doc/make.bat
Normal file
35
doc/make.bat
Normal file
@@ -0,0 +1,35 @@
|
||||
@ECHO OFF
|
||||
|
||||
pushd %~dp0
|
||||
|
||||
REM Command file for Sphinx documentation
|
||||
|
||||
if "%SPHINXBUILD%" == "" (
|
||||
set SPHINXBUILD=sphinx-build
|
||||
)
|
||||
set SOURCEDIR=.
|
||||
set BUILDDIR=_build
|
||||
|
||||
%SPHINXBUILD% >NUL 2>NUL
|
||||
if errorlevel 9009 (
|
||||
echo.
|
||||
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
||||
echo.installed, then set the SPHINXBUILD environment variable to point
|
||||
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
||||
echo.may add the Sphinx directory to PATH.
|
||||
echo.
|
||||
echo.If you don't have Sphinx installed, grab it from
|
||||
echo.https://www.sphinx-doc.org/
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
if "%1" == "" goto help
|
||||
|
||||
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
goto end
|
||||
|
||||
:help
|
||||
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
|
||||
:end
|
||||
popd
|
||||
5
doc/reference/documentation/index.md
Normal file
5
doc/reference/documentation/index.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# Documentation
|
||||
|
||||
```{toctree}
|
||||
sphinx
|
||||
```
|
||||
111
doc/reference/documentation/sphinx.md
Normal file
111
doc/reference/documentation/sphinx.md
Normal file
@@ -0,0 +1,111 @@
|
||||
# Sphinx
|
||||
The primary driver of this package's documentation is Sphinx's `autodoc` extension,
|
||||
using the [Furo theme][1].
|
||||
|
||||
**High-level details**:
|
||||
|
||||
- `sphinx-apidoc` generates package-based documentation to the `_autoref/` directory,
|
||||
with navigation available under "Autoref" in the sidebar.
|
||||
- Markdown-based documentation files are manually written under the `reference/`
|
||||
directory, showing up under "Contents" in the sidebar.
|
||||
|
||||
## Detailed directory structure
|
||||
All files are placed under `docs/sphinx`:
|
||||
|
||||
- `_`-prefixed are Sphinx-managed directories
|
||||
* `_build/html/` houses output HTML files
|
||||
* `_autoref/` is the target for module-based RST files written by `autodoc`
|
||||
- `reference/`: houses all manually written documentation (totally separate from
|
||||
auto-generated package docs)
|
||||
- `conf.py`: single Sphinx configuration file
|
||||
- `index.md`: documentation index, setups up a persistent sidebar across all other pages
|
||||
|
||||
For manually written documentation under `reference/`, topics are nested as needed. Within
|
||||
a nested directory `reference/<topic>`, an `index.md` should created with content like:
|
||||
|
||||
```
|
||||
# <Topic>
|
||||
|
||||
\`\`\`{toctree}
|
||||
:hidden:
|
||||
|
||||
sub-topic-1.rst
|
||||
sub-topic-2.rst
|
||||
...
|
||||
\`\`\`
|
||||
```
|
||||
|
||||
This will add the nested directory to the sidebar navigation, using the name set under the
|
||||
top-level header. See [Markdown syntax][#markdown-syntax] for more details on the syntax.
|
||||
|
||||
## Sphinx autodoc
|
||||
Sphinx's `autodoc` extension allows automatic generation of documents according to
|
||||
(Python) subpackage structure and available docstrings. A few notes here:
|
||||
|
||||
- In the `conf.py` file, autodoc is enabled by adding `"sphinx.ext.autodoc"` to
|
||||
the extensions list. `"sphinx.ext.viewcode"` can also be added to provide
|
||||
links to source code.
|
||||
- Documents are actually generated by calling the `sphinx-apidoc` CLI command. The
|
||||
current Makefile uses the following call:
|
||||
|
||||
```sh
|
||||
sphinx-apidoc --module-first -o docs/sphinx/_autoref/ localsys
|
||||
```
|
||||
|
||||
This writes the automatically generated docs for modules in the package at the
|
||||
local directory `localsys/` to the `docs/sphinx/_autoref` directory. These are
|
||||
reStructuredText files by default.
|
||||
* `--module-first` places the module-level descriptions at the top of the module page.
|
||||
By default, this is placed at the bottom (oddly), and can be obscured by large lists
|
||||
of subpackages if this flag isn't provided.
|
||||
* See available `sphinx-apidoc` options [here][2], as well as more advanced config
|
||||
[here][3].
|
||||
|
||||
|
||||
## Markdown syntax
|
||||
The `myst_parser` extension enables Markdown (or something close to it) to be used when
|
||||
writing documentation files. The Sphinx directives can be difficult to track, and
|
||||
they change slightly under the MyST Markdown syntax. The following are a few common
|
||||
blocks:
|
||||
|
||||
**Page hierarchies**: the following will generate link hierarchy according to the provided
|
||||
pages:
|
||||
|
||||
```
|
||||
\`\`\`{toctree}
|
||||
:maxdepth: <n>
|
||||
:caption: <caption>
|
||||
:hidden:
|
||||
|
||||
example-file-1
|
||||
example-file-2
|
||||
example-dir/index
|
||||
...
|
||||
\`\`\`
|
||||
```
|
||||
|
||||
- `:maxdepth:` limits the depth of nesting
|
||||
- `:caption:` title for the group of pages
|
||||
- `:hidden:` if provided, links will only show in the sidebar (hidden on the page)
|
||||
- Constituent files: listed files will be rendered as a link directly. If a listed file
|
||||
has a `{toctree}` directive, this tree will be rendered in place of the page's link as a
|
||||
dropdown. The dropdown will be named according to the file's top-level heading, and
|
||||
clicking directly on the dropdown header will show that page's content. Files found in
|
||||
the tree will be placed as links under the dropdown, recursively subject to same rules
|
||||
described here.
|
||||
|
||||
**Include files**: the following will include file content
|
||||
pages:
|
||||
|
||||
```
|
||||
\`\`\`{include} README.md
|
||||
\`\`\`
|
||||
```
|
||||
|
||||
**Reference directives**
|
||||
|
||||
|
||||
[1]: https://pradyunsg.me/furo/
|
||||
[2]: https://www.sphinx-doc.org/en/master/man/sphinx-apidoc.html
|
||||
[3]: https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#
|
||||
|
||||
20
example/dataset.py
Normal file
20
example/dataset.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import NamedTuple
|
||||
|
||||
from trainlib.domain import SequenceDomain
|
||||
from trainlib.datasets.memory import TupleDataset
|
||||
|
||||
|
||||
class Record(NamedTuple):
|
||||
a: int
|
||||
b: str
|
||||
|
||||
tl_domain = SequenceDomain[Record]([
|
||||
Record(1, "1"),
|
||||
Record(2, "2"),
|
||||
])
|
||||
|
||||
class R0(TupleDataset[Record]):
|
||||
item_tuple = Record
|
||||
|
||||
def _process_item_data(self, item_data, item_index):
|
||||
return (item_data[0],)
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "trainlib"
|
||||
version = "0.1.0"
|
||||
version = "0.1.2"
|
||||
description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training."
|
||||
requires-python = ">=3.13"
|
||||
authors = [
|
||||
@@ -24,11 +24,11 @@ classifiers = [
|
||||
"Intended Audience :: End Users/Desktop",
|
||||
]
|
||||
dependencies = [
|
||||
"torch",
|
||||
"colorama>=0.4.6",
|
||||
"matplotlib>=3.10.8",
|
||||
"numpy>=2.4.1",
|
||||
"tensorboard>=2.20.0",
|
||||
"torch>=2.5.1",
|
||||
"tqdm>=4.67.1",
|
||||
]
|
||||
|
||||
@@ -41,6 +41,7 @@ dev = [
|
||||
]
|
||||
doc = [
|
||||
"furo",
|
||||
# "pydata-sphinx-theme",
|
||||
"myst-parser",
|
||||
"sphinx",
|
||||
"sphinx-togglebutton",
|
||||
@@ -82,3 +83,11 @@ force-sort-within-sections = false
|
||||
quote-style = "double"
|
||||
indent-style = "space"
|
||||
docstring-code-format = true
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = { index = "pytorch" }
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch"
|
||||
url = "https://download.pytorch.org/whl/cu128"
|
||||
explicit = true
|
||||
|
||||
@@ -1,121 +1,62 @@
|
||||
"""
|
||||
Marginalizing out the modality layer:
|
||||
Domain-generic dataset base with attribute-based splitting and balancing
|
||||
|
||||
With ``domain`` being an instance variable, one possible interpretation of
|
||||
the object structures here is that one could completely abstract away
|
||||
the domain model, defining only item structures and processing data. You
|
||||
could have a single dataset definition for a particular concrete dataset,
|
||||
and so long as we're talking about the same items, it can be instantiated
|
||||
using *any domain*. You wouldn't need specific subclasses for disk or
|
||||
network or in-memory; you can tell it directly at runtime.
|
||||
**Marginalizing out the modality layer**
|
||||
|
||||
That's an eventually possibility, anyway. As it stands, however, this is
|
||||
effectively impossible:
|
||||
With ``domain`` being an instance variable, one possible interpretation of
|
||||
the object structures here is that one could completely abstract away
|
||||
the domain model, defining only item structures and processing data. You
|
||||
could have a single dataset definition for a particular concrete dataset,
|
||||
and so long as we're talking about the same items, it can be instantiated
|
||||
using *any domain*. You wouldn't need specific subclasses for disk or
|
||||
network or in-memory structures; you can tell it directly at runtime.
|
||||
|
||||
You can't easily abstract the batch -> item splitting process, i.e.,
|
||||
``_process_batch_data()``. A list-based version of the dataset you're
|
||||
trying to define might have an individual item tuple at every index,
|
||||
whereas a disk-based version might have tuples batched across a few files.
|
||||
This can't reliably be inferred, nor can it be pushed to the
|
||||
``Domain``-level without needing equal levels of specialization (you'd just
|
||||
end up needing the exact same structural distinctions in the ``Domain``
|
||||
hierarchy). So *somewhere* you need a batch splitting implementation that
|
||||
is both item structure-dependent *and* domain-dependent...the question is
|
||||
how dynamic you're willing to be about where it comes from. Right now, we
|
||||
require this actually be defined in the ``_process_batch_data()`` method,
|
||||
meaning you'll need a specific ``Dataset`` class for each domain you want
|
||||
to support (e.g., ``MNISTDisk``, ``MNISTList``, ``MNISTNetwork``, etc), or
|
||||
at least for each domain where "interpreting" a batch could possibly
|
||||
differ. This is a case where the interface is all that enforces a
|
||||
distinction: if you've got two domains that can be counted on to yield
|
||||
batches in the exact same way and can use the same processing, then you
|
||||
could feasibly provide ``Domain`` objects from either at runtime and have
|
||||
no issues. We're "structurally blind" to any differentiation beyond the URI
|
||||
and resource types by design, so two different domain implementations with
|
||||
the same type signature ``Domain[U, R]`` should be expected to work fine at
|
||||
runtime (again, so long as they don't also need different batch
|
||||
processing), but that's not affording us much flexibility, i.e., most of
|
||||
the time we'll still be defining new dataset classes for each domain.
|
||||
That's an eventually possibility, anyway. As it stands, however, this is
|
||||
effectively impossible:
|
||||
|
||||
I initially flagged this as feasible, however, because one could imagine
|
||||
accepting a batch processing method upon instantiation rather than
|
||||
structurally bolting it into the ``Dataset`` definition. This would require
|
||||
knowledge of the item structure ``I`` as well as the ``Domain[U, R]``, so
|
||||
such a function will always have to be (I, U, R)-dependent. It nevertheless
|
||||
would take out some of the pain of having to define new dataset classes;
|
||||
instead, you'd just need to define the batch processing method. I see this
|
||||
as a worse alternative to just defining *inside* a safe context like a new
|
||||
dataset class: you know the types you have to respect, and you stick that
|
||||
method exactly in a context where it's understood. Freeing this up doesn't
|
||||
lighten the burden of processing logic, it just changes *when* it has to be
|
||||
provided, and that's not worth much (to me) in this case given the bump in
|
||||
complexity. (Taking this to the extreme: you could supply *all* of an
|
||||
object's methods "dynamically" and glue them together at runtime so long as
|
||||
they all played nice. But wherever you were "laying them out" beforehand is
|
||||
exactly the job of a class to begin with, so you don't end up with anything
|
||||
more dynamic. All we're really discussing here is pushing around
|
||||
unavoidable complexity inside and outside of the "class walls," and in the
|
||||
particular case of ``_process_batch_data()``, it feels much better when
|
||||
it's on the inside.)
|
||||
You can't easily abstract the batch-to-item splitting process, i.e.,
|
||||
``_process_batch_data()``. A list-based version of the dataset you're trying to
|
||||
define might have an individual item tuple at every index, whereas a disk-based
|
||||
version might have tuples batched across a few files. This can't reliably be
|
||||
inferred, nor can it be pushed to the ``Domain``-level without needing equal
|
||||
levels of specialization (you'd just end up needing the exact same structural
|
||||
distinctions in the ``Domain`` hierarchy). So *somewhere* you need a batch
|
||||
splitting implementation that is both item structure-dependent *and*
|
||||
domain-dependent...the question is how dynamic you're willing to be about where
|
||||
it comes from. Right now, we require this actually be defined in the
|
||||
``_process_batch_data()`` method, meaning you'll need a specific ``Dataset``
|
||||
class for each domain you want to support (e.g., ``MNISTDisk``, ``MNISTList``,
|
||||
``MNISTNetwork``, etc), or at least for each domain where "interpreting" a
|
||||
batch could possibly differ. This is a case where the interface is all that
|
||||
enforces a distinction: if you've got two domains that can be counted on to
|
||||
yield batches in the exact same way and can use the same processing, then you
|
||||
could feasibly provide ``Domain`` objects from either at runtime and have no
|
||||
issues. We're "structurally blind" to any differentiation beyond the URI and
|
||||
resource types by design, so two different domain implementations with the same
|
||||
type signature ``Domain[U, R]`` should be expected to work fine at runtime
|
||||
(again, so long as they don't also need different batch processing), but that's
|
||||
not affording us much flexibility, i.e., most of the time we'll still be
|
||||
defining new dataset classes for each domain.
|
||||
|
||||
Holding:
|
||||
@abstractmethod
|
||||
def _get_uri_groups(self) -> Iterable[tuple[U, ...]]:
|
||||
Get URI groups for each batch.
|
||||
|
||||
If there's more than one URI per batch (e.g., a data file and a
|
||||
metadata file), zip the URIs such that we have a tuple of URIs per
|
||||
batch.
|
||||
|
||||
Note that this effectively defines the index style over batches in the
|
||||
attached domain. We get an ``int -> tuple[U, ...]`` map that turns
|
||||
batch indices into URIs that can be read under the domain.
|
||||
``get_batch()`` turns an integer index into its corresponding
|
||||
``tuple[U, ...]``, reading the resources with ``_read_resources()`` in
|
||||
the tuple, treating them as providers of batched data.
|
||||
``_read_resources()`` passes through to the attached domain logic,
|
||||
which, although common, need not supply an explicit iterable of batch
|
||||
items: we just access items with ``__getitem__()`` and may ask for
|
||||
``__len__``. So the returned URI group collection (this method) does
|
||||
need to be iterable to measure the number of batches, but the batch
|
||||
objects that are ultimately produced by these URI groups need not be
|
||||
iterables themselves.
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def _read_resources(
|
||||
self,
|
||||
uri_group: tuple[U, ...],
|
||||
batch_index: int
|
||||
) -> tuple[R, ...]:
|
||||
Read batch files at the provided paths.
|
||||
|
||||
This method should operate on a single tuple from the list of batch
|
||||
tuples returned by the ``_get_uri_groups()`` method. That is, it reads
|
||||
all of the resources for a single batch and returns a tuple of the same
|
||||
size with their contents.
|
||||
|
||||
Note: the dependence on a batch index is mostly here to make
|
||||
multi-dataset composition easier later. In-dataset, you don't need to
|
||||
know the batch index to to simply process URIs, but across datasets you
|
||||
need it to find out the origin of the batch (and process those URIs
|
||||
accordingly).
|
||||
|
||||
return tuple(self.domain.read(uri) for uri in uri_group)
|
||||
|
||||
# pulling the type variable out of the inline generic b/c `ty` has trouble
|
||||
# understanding bound type variables in subclasses (specifically with Self@)
|
||||
T = TypeVar("T", bound=NamedTuple)
|
||||
|
||||
class NamedTupleDataset[I](Dataset):
|
||||
def __init__(self, data_list: list[I]) -> None:
|
||||
self.data_list = data_list
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.data_list)
|
||||
|
||||
def __getitem__(self, index: int) -> I:
|
||||
return self.data_list[index]
|
||||
I initially flagged this as feasible, however, because one could imagine
|
||||
accepting a batch processing method upon instantiation rather than structurally
|
||||
bolting it into the ``Dataset`` definition. This would require knowledge of the
|
||||
item structure ``I`` as well as the ``Domain[U, R]``, so such a function will
|
||||
always have to be ``(I, U, R)``-dependent. It nevertheless would take out some
|
||||
of the pain of having to define new dataset classes; instead, you'd just need
|
||||
to define the batch processing method. I see this as a worse alternative to
|
||||
just defining *inside* a safe context like a new dataset class: you know the
|
||||
types you have to respect, and you stick that method exactly in a context where
|
||||
it's understood. Freeing this up doesn't lighten the burden of processing
|
||||
logic, it just changes *when* it has to be provided, and that's not worth much
|
||||
(to me) in this case given the bump in complexity. (Taking this to the extreme:
|
||||
you could supply *all* of an object's methods "dynamically" and glue them
|
||||
together at runtime so long as they all played nice. But wherever you were
|
||||
"laying them out" beforehand is exactly the job of a class to begin with, so
|
||||
you don't end up with anything more dynamic. All we're really discussing here
|
||||
is pushing around unavoidable complexity inside and outside of the "class
|
||||
walls," and in the particular case of ``_process_batch_data()``, it feels much
|
||||
better when it's on the inside.)
|
||||
"""
|
||||
|
||||
import math
|
||||
@@ -154,40 +95,77 @@ class BatchedDataset[U, R, I](Dataset):
|
||||
|
||||
The class is generic over a URI type ``U``, a resource type ``R`` (both of
|
||||
which are used to concretize a domain ``Domain[U, R]``), and an item type
|
||||
``T`` (which has a ``tuple`` upper bound).
|
||||
``I``.
|
||||
|
||||
Pipeline overview:
|
||||
**Batch and item processing flow**
|
||||
|
||||
```
|
||||
Domain -> [U] (get _batch_uris)
|
||||
U -> R (domain access ; Rs provide batches)
|
||||
R -> [I] (cache here ; _process_batch_data to use load_transform)
|
||||
[I] -> I (human item obj ; _get_item)
|
||||
I -> **P (final packed item ; __getitem__ to use transform)
|
||||
```
|
||||
.. code-block:: text
|
||||
|
||||
Note^1: as far as positioning, this class is meant to play nice with
|
||||
PyTorch DataLoaders, hence the inheritance from ``torch.Dataset``. The
|
||||
value add for this over the ``torch.Dataset`` base is almost entirely in
|
||||
the logic it implements to map out of *batched resources* that are holding
|
||||
data, and flattening it out into typical dataset items. There are also some
|
||||
QoL items when it comes to splitting and balancing samples.
|
||||
Domain -> [U] :: self._batch_uris = list(domain)
|
||||
|
||||
Note^2: even though ``Domains`` implement iterators over their URIs, this
|
||||
doesn't imply a ``BatchedDataset`` is iterable. This just means we can walk
|
||||
over the resources that provide data, but we don't necessarily presuppose
|
||||
an ordered walk over samples within batches. Point being:
|
||||
``torch.Dataset``, not ``torch.IterableDataset``, is the appropriate
|
||||
superclass, even when we're working around iterable ``Domains``.
|
||||
Grab all URIs from Domain iterators. This is made concrete early to
|
||||
allow for Dataset sizing, and we need a Sequence representation to
|
||||
map integer batch indices into Domains, i.e., when getting the
|
||||
corresponding URI:
|
||||
|
||||
Note^3: transforms are expected to operate on ``I``-items and produce
|
||||
``I``-items. They shouldn't be the "introducers" of ``I`` types from some
|
||||
other intermediate representation, nor should they map from ``I`` to
|
||||
something else. Point being: the dataset definition should be able to map
|
||||
resources ``R`` to ``I`` without a transform: that much should be baked
|
||||
into the class definition. If you find you're expecting the transform to do
|
||||
that for you, you should consider pulling in some common structure across
|
||||
the allowed transforms and make it a fixed part of the class.
|
||||
batch_uri = self._batch_uris[batch_index]
|
||||
|
||||
We let Domains implement iterators over their URIs, but explicitly
|
||||
exhaust when initializing Datasets.
|
||||
|
||||
U -> R :: batch_data = self.domain[batch_uri]
|
||||
|
||||
Retrieve resource from domain. Resources are viewed as batched
|
||||
data, even if only wrapping single items (happens in trivial
|
||||
settings).
|
||||
|
||||
R -> [I] :: self._process_batch_data(batch_data, batch_index)
|
||||
|
||||
Possibly domain-specific batch processing of resource data into
|
||||
explicit Sequence-like structures of items, each of which is
|
||||
subject to the provided pre_transform. Processed batches at this
|
||||
stage are cached (if enabled).
|
||||
|
||||
[I] -> I :: self.get_batch(batch_index)[index_in_batch]
|
||||
|
||||
Select individual items from batches in _get_item. At this stage,
|
||||
items are in intermediate states and pulled from the cached
|
||||
batches.
|
||||
|
||||
I -> I :: self._process_item_data(item_data, index)
|
||||
|
||||
Produce final items with __getitem__, getting intermediate items
|
||||
via _get_item and applying the provided post_transform.
|
||||
|
||||
.. note::
|
||||
|
||||
As far as positioning, this class is meant to play nice with PyTorch
|
||||
DataLoaders, hence the inheritance from ``torch.Dataset``. The value
|
||||
add for this over the ``torch.Dataset`` base is almost entirely in the
|
||||
logic it implements to map out of *batched resources* that are holding
|
||||
data, and flattening it out into typical dataset items. There are also
|
||||
some QoL features when it comes to splitting and balancing samples.
|
||||
|
||||
.. note::
|
||||
|
||||
Even though ``Domains`` implement iterators over their URIs, this
|
||||
doesn't imply a ``BatchedDataset`` is iterable. This just means we can
|
||||
walk over the resources that provide data, but we don't necessarily
|
||||
presuppose an ordered walk over samples within batches. Point being:
|
||||
``torch.Dataset``, not ``torch.IterableDataset``, is the appropriate
|
||||
superclass, even when we're working around iterable ``Domains``.
|
||||
|
||||
.. note::
|
||||
|
||||
Transforms are expected to operate on ``I``-items and produce
|
||||
``I``-items. They shouldn't be the "introducers" of ``I`` types from
|
||||
some other intermediate representation, nor should they map from ``I``
|
||||
to something else. Point being: the dataset definition should be able
|
||||
to map resources ``R`` to ``I`` without a transform: that much should
|
||||
be baked into the class definition. If you find you're expecting the
|
||||
transform to do that for you, you should consider pulling in some
|
||||
common structure across the allowed transforms and make it a fixed part
|
||||
of the class.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -201,6 +179,7 @@ class BatchedDataset[U, R, I](Dataset):
|
||||
) -> None:
|
||||
"""
|
||||
Parameters:
|
||||
domain: ``Domain`` object providing access to batched data
|
||||
pre_transform: transform to apply over items during loading (in
|
||||
``_process_batch_data()``), i.e., *before* going into
|
||||
persistent storage
|
||||
@@ -210,6 +189,7 @@ class BatchedDataset[U, R, I](Dataset):
|
||||
batch_cache_limit: the max number of max batches to cache at any
|
||||
one time
|
||||
preload: whether to load all data into memory during instantiation
|
||||
num_workers: number of workers to use when preloading data
|
||||
"""
|
||||
|
||||
self.domain = domain
|
||||
@@ -249,6 +229,9 @@ class BatchedDataset[U, R, I](Dataset):
|
||||
The behavior of this method can vary depending on what we know about
|
||||
batch sizes, and should therefore be implemented by inheriting classes.
|
||||
|
||||
Parameters:
|
||||
item_index: index of item
|
||||
|
||||
Returns:
|
||||
batch_index: int
|
||||
index_in_batch: int
|
||||
@@ -292,6 +275,10 @@ class BatchedDataset[U, R, I](Dataset):
|
||||
place to use a provided ``post_transform``; items are pulled from the
|
||||
cache (if enabled) and processed before being returned as the final
|
||||
tuple outputs (so this processing is not persistent).
|
||||
|
||||
Parameters:
|
||||
item_data: item data
|
||||
item_index: index of item
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
@@ -307,6 +294,9 @@ class BatchedDataset[U, R, I](Dataset):
|
||||
|
||||
Note that return values from `__getitem__()` are "cleaned up" versions
|
||||
of this representation, with minimal info needed for training.
|
||||
|
||||
Parameters:
|
||||
item_index: index of item
|
||||
"""
|
||||
|
||||
if item_index >= len(self):
|
||||
@@ -341,10 +331,13 @@ class BatchedDataset[U, R, I](Dataset):
|
||||
they're always connected, and nothing would notice if you waited
|
||||
between steps. The only way this could matter is if you split the
|
||||
resource reading and batch processing steps across methods, but when it
|
||||
actually comes to accessing/caching the batch, you'd have to expand
|
||||
any delayed reads here. There's no way around needing to see all batch
|
||||
data at once here, and we don't want to make that ambiguous: ``list``
|
||||
output type it is.
|
||||
actually comes to accessing/caching the batch, you'd have to expand any
|
||||
delayed reads here. There's no way around needing to see all batch data
|
||||
at once here, and we don't want to make that ambiguous: ``list`` output
|
||||
type it is.
|
||||
|
||||
Parameters:
|
||||
batch_index: index of batch
|
||||
"""
|
||||
|
||||
logger.debug("Batch cache miss, reading from root...")
|
||||
@@ -364,6 +357,9 @@ class BatchedDataset[U, R, I](Dataset):
|
||||
Can be useful when dynamically pulling data (as it's requested) isn't
|
||||
desired. Requires that `cache_sample_limit=None`, i.e., the cache won't
|
||||
continually remove previous batches as they're loaded.
|
||||
|
||||
Parameters:
|
||||
num_workers: number of parallel workers to use for data loading
|
||||
"""
|
||||
|
||||
assert self.batch_cache_limit is None, "Preloading under cache limit"
|
||||
@@ -396,36 +392,46 @@ class BatchedDataset[U, R, I](Dataset):
|
||||
"""
|
||||
Split dataset into fractional pieces by data attribute.
|
||||
|
||||
If `by_attr` is None, recovers typical fractional splitting of dataset
|
||||
items, partitioning by size. Using None anywhere will index each item
|
||||
into its own bucket, i.e., by its index. For instance,
|
||||
If ``by_attr`` is None, recovers typical fractional splitting of
|
||||
dataset items, partitioning by size. Using None anywhere will index
|
||||
each item into its own bucket, i.e., by its index. For instance:
|
||||
|
||||
- by_attr=["color"] -> {("red", 1), ("red", 2)},
|
||||
{("blue", 1), ("blue", 2)}
|
||||
- Splits on the attribute such that each subset contains entire strata
|
||||
of the attribute. "Homogeneity within clusters:"
|
||||
|
||||
Splits on the attribute such that each subset contains entire strata
|
||||
of the attribute. "Homogeneity within clusters"
|
||||
.. code-block::
|
||||
|
||||
- `by_attr=["color", None]` -> {("red", 1), ("blue", 1)},
|
||||
{("red", 2), ("blue", 2)}
|
||||
by_attr=["color"] -> {("red", 1), ("red", 2)},
|
||||
{("blue", 1), ("blue", 2)}
|
||||
|
||||
Stratifies by attribute and then splits "by index" within, uniformly
|
||||
- Stratifies by attribute and then splits "by index" within, uniformly
|
||||
grabbing samples across strata to form new clusters. "Homogeneity
|
||||
across clusters"
|
||||
|
||||
.. code-block::
|
||||
|
||||
by_attr=["color", None] -> {("red", 1), ("blue", 1)},
|
||||
{("red", 2), ("blue", 2)}
|
||||
|
||||
Note that the final list of Subsets returned are built from shallow
|
||||
copies of the underlying dataset (i.e., `self`) to allow manual
|
||||
copies of the underlying dataset (i.e., ``self``) to allow manual
|
||||
intervention with dataset attributes (e.g., setting the splits to have
|
||||
different `transform`s). This is subject to possibly unexpected
|
||||
different ``transforms``). This is subject to possibly unexpected
|
||||
behavior if re-caching data or you need a true copy of all data in
|
||||
memory, but should otherwise leave most interactions unchanged.
|
||||
|
||||
Parameters:
|
||||
frac: split fractions for datasets
|
||||
dataset: dataset to split, defaults to ``self``. Facilitates
|
||||
recursive splitting when multi-attribute splits are needed.
|
||||
by_attr: attribute or attributes to use when grouping strata for
|
||||
dataset splits. Defaults to ``None``, which will use item
|
||||
indices.
|
||||
shuffle_strata: shuffle the strata order before split is drawn. We
|
||||
parameterize this because a dataloader-level shuffle operation
|
||||
parameterize this because a Dataloader-level shuffle operation
|
||||
will only change the order of the indices in the resulting
|
||||
splits; only a shuffle of items inside the strata can change
|
||||
the actual content of the splits themselves.
|
||||
splits; only a shuffle of the strata order can change the
|
||||
actual content of the splits themselves.
|
||||
"""
|
||||
|
||||
if by_attr == []:
|
||||
@@ -534,6 +540,32 @@ class BatchedDataset[U, R, I](Dataset):
|
||||
split_max_sizes: list[int] | None = None,
|
||||
shuffle_strata: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Balance the distribution of provided attributes over dataset items.
|
||||
|
||||
This method sets the indices over the dataset according to the result
|
||||
of the rebalancing. The indices are produced by the recursive
|
||||
``_balance()`` method, which is necessarily separate due to the need
|
||||
for a contained recursive approach that doesn't change the underlying
|
||||
dataset during execution.
|
||||
|
||||
Parameters:
|
||||
dataset: dataset to split, defaults to ``self``. Facilitates
|
||||
recursive splitting when multi-attribute splits are needed.
|
||||
by_attr: attribute or attributes to use when grouping strata for
|
||||
dataset splits. Defaults to ``None``, which will use item
|
||||
indices.
|
||||
split_min_sizes: minimum allowed sizes of splits. Must have the
|
||||
same length as ``by_attr``.
|
||||
split_max_sizes: maximum allowed sizes of splits. Must have the
|
||||
same length as ``by_attr``.
|
||||
shuffle_strata: shuffle the strata order before split is drawn. We
|
||||
parameterize this because a Dataloader-level shuffle operation
|
||||
will only change the order of the indices in the resulting
|
||||
splits; only a shuffle of the strata order can change the
|
||||
actual content of the splits themselves.
|
||||
"""
|
||||
|
||||
self.indices = self._balance(
|
||||
dataset,
|
||||
by_attr,
|
||||
@@ -551,9 +583,29 @@ class BatchedDataset[U, R, I](Dataset):
|
||||
shuffle_strata: bool = True,
|
||||
) -> list[int]:
|
||||
"""
|
||||
Note: behavior is a little odd for nested behavior; not exactly
|
||||
perfectly uniform throughout. This is a little difficult: you can't
|
||||
exactly know ahead of time the size of the subgroups across splits
|
||||
Recursive balancing of items by attribute.
|
||||
|
||||
.. note::
|
||||
|
||||
Behavior is a little odd for nested behavior; not exactly perfectly
|
||||
uniform throughout. This is a little difficult: you can't exactly
|
||||
know ahead of time the size of the subgroups across splits
|
||||
|
||||
Parameters:
|
||||
dataset: dataset to split, defaults to ``self``. Facilitates
|
||||
recursive splitting when multi-attribute splits are needed.
|
||||
by_attr: attribute or attributes to use when grouping strata for
|
||||
dataset splits. Defaults to ``None``, which will use item
|
||||
indices.
|
||||
split_min_sizes: minimum allowed sizes of splits. Must have the
|
||||
same length as ``by_attr``.
|
||||
split_max_sizes: maximum allowed sizes of splits. Must have the
|
||||
same length as ``by_attr``.
|
||||
shuffle_strata: shuffle the strata order before split is drawn. We
|
||||
parameterize this because a Dataloader-level shuffle operation
|
||||
will only change the order of the indices in the resulting
|
||||
splits; only a shuffle of the strata order can change the
|
||||
actual content of the splits themselves.
|
||||
"""
|
||||
|
||||
if by_attr == []:
|
||||
@@ -643,6 +695,9 @@ class BatchedDataset[U, R, I](Dataset):
|
||||
dataset. The underlying data remain the same, but when indices get set,
|
||||
you're effectively applying a mask over any existing indices, always
|
||||
operating *relative* to the existing mask.
|
||||
|
||||
Parameters:
|
||||
indices: list of indices to set
|
||||
"""
|
||||
|
||||
# manually set new size
|
||||
@@ -670,6 +725,13 @@ class BatchedDataset[U, R, I](Dataset):
|
||||
return self._dataset_len
|
||||
|
||||
def __getitem__(self, index: int) -> I:
|
||||
"""
|
||||
Get the dataset item at the specified index.
|
||||
|
||||
Parameters:
|
||||
index: index of item to retrieve
|
||||
"""
|
||||
|
||||
item_data = self._get_item(index)
|
||||
index = self.indices[index]
|
||||
|
||||
@@ -691,9 +753,10 @@ class CompositeBatchedDataset[U, R, I](BatchedDataset[U, R, I]):
|
||||
"""
|
||||
Dataset class for wrapping individual datasets.
|
||||
|
||||
Note: because this remains a valid ``BatchedDataset``, we re-thread the
|
||||
generic type variables through the set of composed datasets. That is, they
|
||||
must have a common domain type ``Domain[U, R]``.
|
||||
.. note::
|
||||
Because this remains a valid ``BatchedDataset``, we re-thread the
|
||||
generic type variables through the set of composed datasets. That is,
|
||||
they must have a common domain type ``Domain[U, R]``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -878,7 +941,7 @@ class HomogenousDataset[U, R, I](BatchedDataset[U, R, I]):
|
||||
|
||||
class HeterogenousDataset[U, R, I](BatchedDataset[U, R, I]):
|
||||
"""
|
||||
Batched dataset where batches have arbitrary size.
|
||||
Batched dataset where batches may have arbitrary size.
|
||||
|
||||
Methods left for inheriting classes:
|
||||
|
||||
|
||||
@@ -12,39 +12,38 @@ class DiskDataset[T: NamedTuple](HomogenousDataset[Path, bytes, T]):
|
||||
"""
|
||||
The following line is to satisfy the type checker, which
|
||||
|
||||
1. Can't recognize an appropriately re-typed constructor arg like
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
domain: DiskDomain,
|
||||
...
|
||||
): ...
|
||||
1. Can't recognize an appropriately re-typed constructor arg like::
|
||||
|
||||
This *does* match the parent generic for the U=Path, R=bytes context
|
||||
def __init__(
|
||||
self,
|
||||
domain: DiskDomain,
|
||||
...
|
||||
): ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
domain: Domain[U, R],
|
||||
...
|
||||
): ...
|
||||
This *does* match the parent generic for the ``U=Path``, ``R=bytes``
|
||||
context::
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
domain: Domain[U, R],
|
||||
...
|
||||
): ...
|
||||
|
||||
but the type checker doesn't see this.
|
||||
|
||||
2. "Lifted" type variables out of generics can't be used as upper bounds,
|
||||
at least not without throwing type checker warnings (thanks to PEP695).
|
||||
So I'm not allowed to have
|
||||
So I'm not allowed to have::
|
||||
|
||||
```
|
||||
class BatchedDataset[U, R, D: Domain[U, R]]:
|
||||
...
|
||||
```
|
||||
class BatchedDataset[U, R, D: Domain[U, R]]:
|
||||
...
|
||||
|
||||
which could bring appropriately dynamic typing for ``Domain``s, but is
|
||||
which could bring appropriately dynamic typing for ``Domains``, but is
|
||||
not a sufficiently concrete upper bound.
|
||||
|
||||
So: we settle for a class-level type declaration, which despite not being
|
||||
technically appropriately scoped, it's not harming anything and satisfies
|
||||
``ty`` type checks downstream (e.g., when we access ``DiskDomain.root``.
|
||||
``ty`` type checks downstream (e.g., when we access ``DiskDomain.root``).
|
||||
"""
|
||||
|
||||
domain: DiskDomain
|
||||
|
||||
@@ -1,24 +1,5 @@
|
||||
"""
|
||||
Defines a knowledge domain. Wraps a Dataset / Simulator / Knowledge
|
||||
|
||||
Downstream exploration might include
|
||||
|
||||
- Calibrating Simulator / Knowledge with a Dataset
|
||||
- Amending Dataset with Simulator / Knowledge
|
||||
- Positioning Knowledge within Simulator context
|
||||
* Where to replace Simulator subsystem with Knowledge?
|
||||
|
||||
Other variations:
|
||||
|
||||
- Multi-fidelity simulators
|
||||
- Multi-scale models
|
||||
- Multi-system
|
||||
- Incomplete knowledge / divergence among sources
|
||||
|
||||
Questions:
|
||||
|
||||
- Should Simulator / Knowledge be unified as one (e.g., "Expert")
|
||||
|
||||
Generic URI-resource mapping structure
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping, Iterator, Sequence
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
"""
|
||||
Base class for trainable models
|
||||
|
||||
Development note
|
||||
|
||||
I'd rather lay out bare args and kwargs in the estimator methods, but the
|
||||
@@ -24,7 +26,7 @@ from torch import nn, Tensor
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from trainlib.util.type import OptimizerKwargs
|
||||
from trainlib.utils.type import OptimizerKwargs
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -107,8 +107,14 @@ class LSTM[K: RNNKwargs](Estimator[K]):
|
||||
with torch.no_grad():
|
||||
loss = next(self.loss(**kwargs)).item()
|
||||
|
||||
predictions = self(**kwargs)[0]
|
||||
labels = kwargs["labels"]
|
||||
mae = F.l1_loss(predictions, labels).item()
|
||||
|
||||
return {
|
||||
"loss": loss,
|
||||
"mse": loss,
|
||||
"mae": mae,
|
||||
"grad_norm": get_grad_norm(self)
|
||||
}
|
||||
|
||||
@@ -291,7 +297,7 @@ class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]):
|
||||
logger.info(f"| > {self.output_dim=}")
|
||||
|
||||
|
||||
class ConvRNN[K: RNNKwargs](Estimator[K]):
|
||||
class ConvGRU[K: RNNKwargs](Estimator[K]):
|
||||
"""
|
||||
Base recurrent convolutional architecture.
|
||||
|
||||
@@ -441,11 +447,18 @@ class ConvRNN[K: RNNKwargs](Estimator[K]):
|
||||
with torch.no_grad():
|
||||
loss = next(self.loss(**kwargs)).item()
|
||||
|
||||
predictions = self(**kwargs)[0].squeeze(-1)
|
||||
labels = kwargs["labels"]
|
||||
mae = F.l1_loss(predictions, labels).item()
|
||||
|
||||
return {
|
||||
"loss": loss,
|
||||
"mse": loss,
|
||||
"mae": mae,
|
||||
"grad_norm": get_grad_norm(self)
|
||||
}
|
||||
|
||||
|
||||
def optimizers(
|
||||
self,
|
||||
**kwargs: Unpack[OptimizerKwargs],
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, Tensor
|
||||
from torch.optim import Optimizer
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
"""
|
||||
Core interface for training ``Estimators`` with ``Datasets``
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
@@ -11,6 +15,7 @@ from collections.abc import Callable
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from torch import cuda, Tensor
|
||||
from torch.optim import Optimizer
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
@@ -30,7 +35,15 @@ logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
class Trainer[I, K: EstimatorKwargs]:
|
||||
"""
|
||||
Training interface for updating ``Estimators`` with ``Datasets``.
|
||||
Training interface for optimizing parameters of ``Estimators`` with
|
||||
``Datasets``.
|
||||
|
||||
This class is generic to a dataset item type ``I`` and an estimator kwarg
|
||||
type ``K``. These are the two primary components ``Trainer`` objects need
|
||||
to coordinate: they ultimately rely on a provided map to ensure data items
|
||||
(type ``I``) from a dataset are appropriately routed as inputs to key
|
||||
estimator methods (like ``forward()`` and ``loss()``), which accept inputs
|
||||
of type ``K``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -42,8 +55,10 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
) -> None:
|
||||
"""
|
||||
Parameters:
|
||||
estimator: `Estimator` model object
|
||||
estimator: ``Estimator`` model object
|
||||
device: device on which to carry out training
|
||||
chkpt_dir: directory to write model checkpoints
|
||||
tblog_dir: directory to write TensorBoard logs
|
||||
"""
|
||||
|
||||
self.device: str
|
||||
@@ -86,7 +101,7 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
Set base tracking parameters.
|
||||
Set initial tracking parameters for the primary training loop.
|
||||
"""
|
||||
|
||||
self._step: int = 0
|
||||
@@ -98,6 +113,151 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
self._stagnant_epochs = 0
|
||||
self._best_model_state_dict: dict[str, Any] = {}
|
||||
|
||||
def _train_epoch(
|
||||
self,
|
||||
train_loader: DataLoader,
|
||||
batch_estimator_map: Callable[[I, Self], K],
|
||||
optimizers: tuple[Optimizer, ...],
|
||||
writer: SummaryWriter,
|
||||
max_grad_norm: float | None = None,
|
||||
) -> list[float]:
|
||||
"""
|
||||
Train the estimator for a single epoch.
|
||||
"""
|
||||
|
||||
train_loss_sums = []
|
||||
self.estimator.train()
|
||||
with tqdm(train_loader, unit="batch") as train_epoch:
|
||||
for i, batch_data in enumerate(train_epoch):
|
||||
est_kwargs = batch_estimator_map(batch_data, self)
|
||||
inputs = est_kwargs["inputs"]
|
||||
|
||||
# one-time logging
|
||||
if self._step == 0:
|
||||
writer.add_graph(
|
||||
ModelWrapper(self.estimator),
|
||||
est_kwargs
|
||||
)
|
||||
|
||||
# once-per-epoch logging
|
||||
if i == 0:
|
||||
self.estimator.epoch_write(
|
||||
writer,
|
||||
step=self._step,
|
||||
val=False,
|
||||
**est_kwargs
|
||||
)
|
||||
|
||||
train_losses = self.estimator.loss(**est_kwargs)
|
||||
train_loss_items = []
|
||||
for o_idx, optimizer in enumerate(optimizers):
|
||||
optimizer.zero_grad()
|
||||
train_loss = next(train_losses)
|
||||
|
||||
if len(train_loss_sums) <= o_idx:
|
||||
train_loss_sums.append(0.0)
|
||||
|
||||
train_loss_item = train_loss.item()
|
||||
train_loss_sums[o_idx] += train_loss_item
|
||||
train_loss_items.append(train_loss_item)
|
||||
|
||||
train_loss.backward()
|
||||
|
||||
# clip gradients for optimizer's parameters
|
||||
if max_grad_norm is not None:
|
||||
clip_grad_norm_(
|
||||
self._get_optimizer_parameters(optimizer),
|
||||
max_norm=max_grad_norm
|
||||
)
|
||||
|
||||
optimizer.step()
|
||||
|
||||
self._step += len(inputs)
|
||||
|
||||
for train_loss_item, train_loss_sum in zip(
|
||||
train_loss_items,
|
||||
train_loss_sums,
|
||||
strict=True,
|
||||
):
|
||||
train_epoch.set_postfix(loss=f"{train_loss_sum/(i+1):8.2f}")
|
||||
self._add_summary_item("train_loss", train_loss_item)
|
||||
|
||||
estimator_metrics = self.estimator.metrics(**est_kwargs)
|
||||
for metric_name, metric_value in estimator_metrics.items():
|
||||
self._add_summary_item(
|
||||
f"train_{metric_name}",
|
||||
metric_value
|
||||
)
|
||||
|
||||
self.estimator.epoch_step()
|
||||
|
||||
for li, train_loss_sum in enumerate(train_loss_sums):
|
||||
self._add_summary_item(
|
||||
f"train_loss{li}_epoch", train_loss_sum / len(train_loader)
|
||||
)
|
||||
|
||||
return train_loss_sums
|
||||
|
||||
def _val_epoch(
|
||||
self,
|
||||
val_loader: DataLoader,
|
||||
batch_estimator_map: Callable[[I, Self], K],
|
||||
optimizers: tuple[Optimizer, ...],
|
||||
writer: SummaryWriter,
|
||||
) -> list[float]:
|
||||
"""
|
||||
Perform and record validation scores for a single epoch.
|
||||
"""
|
||||
|
||||
val_loss_sums = []
|
||||
self.estimator.eval()
|
||||
with tqdm(val_loader, unit="batch") as val_epoch:
|
||||
for i, batch_data in enumerate(val_epoch):
|
||||
est_kwargs = batch_estimator_map(batch_data, self)
|
||||
|
||||
# once-per-epoch logging
|
||||
if i == 0:
|
||||
self.estimator.epoch_write(
|
||||
writer,
|
||||
step=self._step,
|
||||
val=True,
|
||||
**est_kwargs
|
||||
)
|
||||
|
||||
val_losses = self.estimator.loss(**est_kwargs)
|
||||
val_loss_items = []
|
||||
for o_idx in range(len(optimizers)):
|
||||
val_loss = next(val_losses)
|
||||
|
||||
if len(val_loss_sums) <= o_idx:
|
||||
val_loss_sums.append(0.0)
|
||||
|
||||
val_loss_item = val_loss.item()
|
||||
val_loss_sums[o_idx] += val_loss_item
|
||||
val_loss_items.append(val_loss_item)
|
||||
|
||||
for val_loss_item, val_loss_sum in zip(
|
||||
val_loss_items,
|
||||
val_loss_sums,
|
||||
strict=True,
|
||||
):
|
||||
val_epoch.set_postfix(loss=f"{val_loss_sum/(i+1):8.2f}")
|
||||
self._add_summary_item("val_loss", val_loss_item)
|
||||
|
||||
estimator_metrics = self.estimator.metrics(**est_kwargs)
|
||||
for metric_name, metric_value in estimator_metrics.items():
|
||||
self._add_summary_item(f"val_{metric_name}", metric_value)
|
||||
|
||||
for li, val_loss_sum in enumerate(val_loss_sums):
|
||||
self._add_summary_item(
|
||||
f"val_loss{li}_epoch", val_loss_sum / len(val_loader)
|
||||
)
|
||||
|
||||
# convergence of multiple losses may be ambiguous
|
||||
self._val_loss = sum(val_loss_sums) / len(val_loader)
|
||||
|
||||
return val_loss_sums
|
||||
|
||||
def train(
|
||||
self,
|
||||
dataset: BatchedDataset[..., ..., I],
|
||||
@@ -122,44 +282,54 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
"""
|
||||
Note: this method attempts to implement a general scheme for passing
|
||||
needed items to the estimator's loss function from the dataloader. The
|
||||
abstract `Estimator` base only requires the model output be provided
|
||||
abstract ``Estimator`` base only requires the model output be provided
|
||||
for any given loss calculation, but concrete estimators will often
|
||||
require additional arguments (e.g., labels or length masks, as
|
||||
is the case with sequential models). In any case, this method defers
|
||||
any further logic to the `loss` method of the underlying estimator, so
|
||||
require additional arguments (e.g., labels or length masks, as is the
|
||||
case with sequential models). In any case, this method defers any
|
||||
further logic to the ``loss`` method of the underlying estimator, so
|
||||
one should take care to synchronize the sample structure with `dataset`
|
||||
to match that expected by `self.estimator.loss(...)`.
|
||||
to match that expected by ``self.estimator.loss(...)``.
|
||||
|
||||
.. admonition:: On ``batch_estimator_map``
|
||||
|
||||
On batch_estimator_map:
|
||||
Dataloader collate functions are responsible for mapping a
|
||||
collection of items into an item of collections, roughly speaking.
|
||||
If items are tuples of tensors,
|
||||
|
||||
Dataloader collate functions are responsible for mapping a collection
|
||||
of items into an item of collections, roughly speaking. If items are
|
||||
tuples of tensors,
|
||||
.. code-block:: text
|
||||
|
||||
[
|
||||
( [1, 1], [1, 1] ),
|
||||
( [2, 2], [2, 2] ),
|
||||
( [3, 3], [3, 3] ),
|
||||
]
|
||||
[
|
||||
( [1, 1], [1, 1] ),
|
||||
( [2, 2], [2, 2] ),
|
||||
( [3, 3], [3, 3] ),
|
||||
]
|
||||
|
||||
the collate function maps back into the item skeleton, producing a
|
||||
single tuple of (stacked) tensors
|
||||
the collate function maps back into the item skeleton, producing a
|
||||
single tuple of (stacked) tensors
|
||||
|
||||
.. code-block:: text
|
||||
|
||||
( [[1, 1],
|
||||
[2, 2],
|
||||
[3, 3]],
|
||||
( [[1, 1],
|
||||
[2, 2],
|
||||
[3, 3]],
|
||||
|
||||
[[1, 1],
|
||||
[2, 2],
|
||||
[3, 3]] )
|
||||
[[1, 1],
|
||||
[2, 2],
|
||||
[3, 3]] )
|
||||
|
||||
This function should map from batches (which should be *item shaped*,
|
||||
i.e., have an `I` skeleton, even if stacked items may be different on
|
||||
the inside) into estimator keyword arguments (type `K`).
|
||||
This function should map from batches (which should be *item
|
||||
shaped*, i.e., have an ``I`` skeleton, even if stacked items may be
|
||||
different on the inside) into estimator keyword arguments (type
|
||||
``K``).
|
||||
|
||||
Parameters:
|
||||
dataset: dataset to train the estimator
|
||||
batch_estimator_map: function mapping from batch data to expected
|
||||
estimator kwargs
|
||||
lr: learning rate (default: 1e-3)
|
||||
eps: adam EPS (default: 1e-8)
|
||||
max_grad_norm: upper bound to use when clipping gradients. If left
|
||||
as ``None``, no gradient clipping is performed.
|
||||
max_epochs: maximum number of training epochs
|
||||
stop_after_epochs: number of epochs with stagnant validation losses
|
||||
to allow before early stopping. If training stops earlier, the
|
||||
@@ -212,132 +382,39 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
while self._epoch <= max_epochs and not self._converged(
|
||||
self._epoch, stop_after_epochs
|
||||
):
|
||||
print(f"Training epoch {self._epoch}/{max_epochs}...")
|
||||
print(f"Stagnant epochs {self._stagnant_epochs}/{stop_after_epochs}...")
|
||||
train_frac = f"{self._epoch}/{max_epochs}"
|
||||
stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}"
|
||||
print(f"Training epoch {train_frac}...")
|
||||
print(f"Stagnant epochs {stag_frac}...")
|
||||
|
||||
epoch_start_time = time.time()
|
||||
train_loss_sums = []
|
||||
self.estimator.train()
|
||||
with tqdm(train_loader, unit="batch") as train_epoch:
|
||||
for i, batch_data in enumerate(train_epoch):
|
||||
est_kwargs = batch_estimator_map(batch_data, self)
|
||||
inputs = est_kwargs["inputs"]
|
||||
|
||||
# one-time logging
|
||||
if self._step == 0:
|
||||
writer.add_graph(ModelWrapper(self.estimator), est_kwargs)
|
||||
|
||||
# once-per-epoch logging
|
||||
if i == 0:
|
||||
self.estimator.epoch_write(
|
||||
writer,
|
||||
step=self._step,
|
||||
val=False,
|
||||
**est_kwargs
|
||||
)
|
||||
|
||||
train_losses = self.estimator.loss(**est_kwargs)
|
||||
train_loss_items = []
|
||||
for o_idx, optimizer in enumerate(optimizers):
|
||||
optimizer.zero_grad()
|
||||
train_loss = next(train_losses)
|
||||
|
||||
if len(train_loss_sums) <= o_idx:
|
||||
train_loss_sums.append(0.0)
|
||||
|
||||
train_loss_item = train_loss.item()
|
||||
train_loss_sums[o_idx] += train_loss_item
|
||||
train_loss_items.append(train_loss_item)
|
||||
|
||||
train_loss.backward()
|
||||
|
||||
# clip gradients for optimizer's parameters
|
||||
if max_grad_norm is not None:
|
||||
opt_params = self._get_optimizer_parameters(optimizer)
|
||||
clip_grad_norm_(opt_params, max_norm=max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
|
||||
self._step += len(inputs)
|
||||
|
||||
for train_loss_item, train_loss_sum in zip(
|
||||
train_loss_items,
|
||||
train_loss_sums,
|
||||
strict=True,
|
||||
):
|
||||
train_epoch.set_postfix(loss=f"{train_loss_sum/(i+1):8.2f}")
|
||||
self._add_summary_item("train_loss", train_loss_item)
|
||||
|
||||
estimator_metrics = self.estimator.metrics(**est_kwargs)
|
||||
for metric_name, metric_value in estimator_metrics.items():
|
||||
self._add_summary_item(f"train_{metric_name}", metric_value)
|
||||
|
||||
self.estimator.epoch_step()
|
||||
|
||||
for li, train_loss_sum in enumerate(train_loss_sums):
|
||||
self._add_summary_item(
|
||||
f"train_loss{li}_epoch", train_loss_sum / len(train_loader)
|
||||
)
|
||||
self._train_epoch(
|
||||
train_loader,
|
||||
batch_estimator_map,
|
||||
optimizers,
|
||||
writer,
|
||||
max_grad_norm
|
||||
)
|
||||
|
||||
if val_frac > 0:
|
||||
val_loss_sums = []
|
||||
self.estimator.eval()
|
||||
with tqdm(val_loader, unit="batch") as val_epoch:
|
||||
for i, batch_data in enumerate(val_epoch):
|
||||
est_kwargs = batch_estimator_map(batch_data, self)
|
||||
inputs = est_kwargs["inputs"]
|
||||
self._val_epoch(
|
||||
val_loader,
|
||||
batch_estimator_map,
|
||||
optimizers,
|
||||
writer,
|
||||
)
|
||||
|
||||
# once-per-epoch logging
|
||||
if i == 0:
|
||||
self.estimator.epoch_write(
|
||||
writer,
|
||||
step=self._step,
|
||||
val=True,
|
||||
**est_kwargs
|
||||
)
|
||||
|
||||
val_losses = self.estimator.loss(**est_kwargs)
|
||||
val_loss_items = []
|
||||
for o_idx in range(len(optimizers)):
|
||||
val_loss = next(val_losses)
|
||||
|
||||
if len(val_loss_sums) <= o_idx:
|
||||
val_loss_sums.append(0.0)
|
||||
|
||||
val_loss_item = val_loss.item()
|
||||
val_loss_sums[o_idx] += val_loss_item
|
||||
val_loss_items.append(val_loss_item)
|
||||
|
||||
for val_loss_item, val_loss_sum in zip(
|
||||
val_loss_items,
|
||||
val_loss_sums,
|
||||
strict=True,
|
||||
):
|
||||
val_epoch.set_postfix(loss=f"{val_loss_sum/(i+1):8.2f}")
|
||||
self._add_summary_item("val_loss", val_loss_item)
|
||||
|
||||
estimator_metrics = self.estimator.metrics(**est_kwargs)
|
||||
for metric_name, metric_value in estimator_metrics.items():
|
||||
self._add_summary_item(f"val_{metric_name}", metric_value)
|
||||
|
||||
for li, val_loss_sum in enumerate(val_loss_sums):
|
||||
self._add_summary_item(
|
||||
f"val_loss{li}_epoch", val_loss_sum / len(val_loader)
|
||||
)
|
||||
|
||||
# convergence of multiple losses may be ambiguous
|
||||
self._val_loss = sum(val_loss_sums) / len(val_loader)
|
||||
|
||||
self._add_summary_item("epoch_time_sec", time.time() - epoch_start_time)
|
||||
self._add_summary_item(
|
||||
"epoch_time_sec",
|
||||
time.time() - epoch_start_time
|
||||
)
|
||||
|
||||
if self._epoch % summarize_every == 0:
|
||||
self._summarize(writer, self._epoch)
|
||||
|
||||
# save checkpoint
|
||||
if self._epoch % chkpt_every == 0:
|
||||
self.save_model(
|
||||
self._epoch, self.chkpt_dir, dir_prefix
|
||||
)
|
||||
self.save_model(self._epoch, self.chkpt_dir, dir_prefix)
|
||||
|
||||
self._epoch += 1
|
||||
|
||||
@@ -431,7 +508,7 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
|
||||
def _summarize(self, writer: SummaryWriter, epoch: int) -> None:
|
||||
"""
|
||||
Flush the training summary to the TB summary writer.
|
||||
Flush the training summary to the TensorBoard summary writer.
|
||||
"""
|
||||
|
||||
summary_values = defaultdict(list)
|
||||
@@ -485,17 +562,18 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
chkpt_dir.mkdir(parents=True, exist_ok=True)
|
||||
chkpt_path.write_bytes(model_buff.getvalue())
|
||||
|
||||
def load_model(
|
||||
self,
|
||||
epoch: int,
|
||||
chkpt_dir: str,
|
||||
) -> None:
|
||||
def load_model(self, epoch: int, chkpt_dir: str) -> None:
|
||||
"""
|
||||
Load a model checkpoint from a given epoch.
|
||||
|
||||
Note that this assumes the model was saved via `Trainer.save_model()`,
|
||||
and the estimator provided to this `Trainer` instance matches the
|
||||
architecture of the checkpoint model being loaded.
|
||||
Note that this assumes the model was saved via
|
||||
``Trainer.save_model()``, and the estimator provided to this
|
||||
``Trainer`` instance matches the architecture of the checkpoint model
|
||||
being loaded.
|
||||
|
||||
Parameters:
|
||||
epoch: epoch of saved model
|
||||
chkpt_dir:
|
||||
"""
|
||||
|
||||
model_class = self.estimator.__class__.__name__
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
"""
|
||||
Transform base for dataset records
|
||||
"""
|
||||
|
||||
class Transform[I]:
|
||||
"""
|
||||
Dataset transform base class.
|
||||
@@ -8,4 +12,14 @@ class Transform[I]:
|
||||
"""
|
||||
|
||||
def __call__(self, item: I) -> I:
|
||||
"""
|
||||
Apply transform to item.
|
||||
|
||||
Parameters:
|
||||
item: item object to transform
|
||||
|
||||
Returns:
|
||||
transformed item (same type ``I`` as input)
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
46
trainlib/utils/custom.mplstyle
Normal file
46
trainlib/utils/custom.mplstyle
Normal file
@@ -0,0 +1,46 @@
|
||||
text.usetex : False
|
||||
mathtext.default : regular
|
||||
|
||||
font.family : sans-serif
|
||||
font.sans-serif : DejaVu Sans
|
||||
font.serif : DejaVu Serif
|
||||
font.cursive : DejaVu Sans
|
||||
mathtext.fontset : dejavuserif
|
||||
font.size : 9
|
||||
figure.titlesize : 9
|
||||
legend.fontsize : 9
|
||||
axes.titlesize : 9
|
||||
axes.labelsize : 9
|
||||
xtick.labelsize : 9
|
||||
ytick.labelsize : 9
|
||||
|
||||
#axes.prop_cycle : cycler('color', ['4f7dd5', 'af7031', '55905e', 'd84739', '888348', 'b75e8b', '2f8f99', '9862cb'])
|
||||
axes.prop_cycle : cycler('color', ['5e8de4', 'c38141', '67a771', 'e15344', '9e9858', '41a6b0', 'a46fd7', 'c86d9a'])
|
||||
|
||||
image.interpolation : nearest
|
||||
image.resample : False
|
||||
image.composite_image : True
|
||||
|
||||
axes.spines.left : True
|
||||
axes.spines.bottom : True
|
||||
axes.spines.top : False
|
||||
axes.spines.right : False
|
||||
|
||||
axes.linewidth : 1
|
||||
xtick.major.width : 1
|
||||
xtick.minor.width : 1
|
||||
ytick.major.width : 1
|
||||
ytick.minor.width : 1
|
||||
|
||||
lines.linewidth : 1
|
||||
lines.markersize : 1
|
||||
|
||||
savefig.dpi : 300
|
||||
savefig.format : svg
|
||||
savefig.bbox : tight
|
||||
savefig.pad_inches : 0.1
|
||||
|
||||
svg.image_inline : True
|
||||
svg.fonttype : none
|
||||
|
||||
legend.frameon : False
|
||||
38
trainlib/utils/plot.py
Normal file
38
trainlib/utils/plot.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib as mpl
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
FILE = Path(__file__).parent.absolute()
|
||||
|
||||
|
||||
class use_style:
|
||||
def __init__(
|
||||
self,
|
||||
style: list[str] | None = None,
|
||||
**kwargs: str,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
if style is None:
|
||||
style = [str(Path(FILE, "custom.mplstyle"))]
|
||||
self.style = style + [kwargs]
|
||||
self.previous_style = {}
|
||||
|
||||
def __enter__(self) -> None:
|
||||
self.previous_style = mpl.rcParams.copy()
|
||||
if self.style is not None:
|
||||
plt.style.use(self.style)
|
||||
|
||||
def __exit__(self, *args: str, **kwargs: str) -> None:
|
||||
mpl.rcParams.update(self.previous_style)
|
||||
|
||||
|
||||
def set_style(
|
||||
style: list[str] | None = None,
|
||||
**kwargs: str,
|
||||
) -> None:
|
||||
if style is None:
|
||||
style = [str(Path(FILE, "custom.mplstyle"))]
|
||||
|
||||
plt.style.use(style + [kwargs])
|
||||
21
trainlib/utils/session.py
Normal file
21
trainlib/utils/session.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def seed_all_backends(seed: int | Tensor | None = None) -> None:
|
||||
"""Sets all python, numpy and pytorch seeds."""
|
||||
|
||||
if seed is None:
|
||||
seed = int(torch.randint(1000000, size=(1,)))
|
||||
else:
|
||||
seed = int(seed)
|
||||
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
100
uv.lock
generated
100
uv.lock
generated
@@ -248,9 +248,13 @@ dependencies = [
|
||||
{ name = "cuda-pathfinder" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/05/8b/b4b2d1c7775fa403b64333e720cfcfccef8dcb9cdeb99947061ca5a77628/cuda_bindings-12.9.4-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cf8bfaedc238f3b115d957d1fd6562b7e8435ba57f6d0e2f87d0e7149ccb2da5", size = 11570071, upload-time = "2025-10-21T14:51:47.472Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/63/56/e465c31dc9111be3441a9ba7df1941fe98f4aa6e71e8788a3fb4534ce24d/cuda_bindings-12.9.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:32bdc5a76906be4c61eb98f546a6786c5773a881f3b166486449b5d141e4a39f", size = 11906628, upload-time = "2025-10-21T14:51:49.905Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ec/07/6aff13bc1e977e35aaa6b22f52b172e2890c608c6db22438cf7ed2bf43a6/cuda_bindings-12.9.4-cp313-cp313t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3adf4958dcf68ae7801a59b73fb00a8b37f8d0595060d66ceae111b1002de38d", size = 11566797, upload-time = "2025-10-21T14:51:54.581Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a3/84/1e6be415e37478070aeeee5884c2022713c1ecc735e6d82d744de0252eee/cuda_bindings-12.9.4-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:56e0043c457a99ac473ddc926fe0dc4046694d99caef633e92601ab52cbe17eb", size = 11925991, upload-time = "2025-10-21T14:51:56.535Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/b5/96a6696e20c4ffd2b327f54c7d0fde2259bdb998d045c25d5dedbbe30290/cuda_bindings-12.9.4-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1f53a7f453d4b2643d8663d036bafe29b5ba89eb904c133180f295df6dc151e5", size = 11624530, upload-time = "2025-10-21T14:52:01.539Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d1/af/6dfd8f2ed90b1d4719bc053ff8940e494640fe4212dc3dd72f383e4992da/cuda_bindings-12.9.4-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8b72ee72a9cc1b531db31eebaaee5c69a8ec3500e32c6933f2d3b15297b53686", size = 11922703, upload-time = "2025-10-21T14:52:03.585Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/39/73/d2fc40c043bac699c3880bf88d3cebe9d88410cd043795382826c93a89f0/cuda_bindings-12.9.4-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:20f2699d61d724de3eb3f3369d57e2b245f93085cab44fd37c3bea036cea1a6f", size = 11565056, upload-time = "2025-10-21T14:52:08.338Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6c/19/90ac264acc00f6df8a49378eedec9fd2db3061bf9263bf9f39fd3d8377c3/cuda_bindings-12.9.4-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d80bffc357df9988dca279734bc9674c3934a654cab10cadeed27ce17d8635ee", size = 11924658, upload-time = "2025-10-21T14:52:10.411Z" },
|
||||
]
|
||||
|
||||
@@ -861,6 +865,7 @@ name = "nvidia-cublas-cu12"
|
||||
version = "12.8.4.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/29/99/db44d685f0e257ff0e213ade1964fc459b4a690a73293220e98feb3307cf/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:b86f6dd8935884615a0683b663891d43781b819ac4f2ba2b0c9604676af346d0", size = 590537124, upload-time = "2025-03-07T01:43:53.556Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921, upload-time = "2025-03-07T01:44:31.254Z" },
|
||||
]
|
||||
|
||||
@@ -869,6 +874,7 @@ name = "nvidia-cuda-cupti-cu12"
|
||||
version = "12.8.90"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d5/1f/b3bd73445e5cb342727fd24fe1f7b748f690b460acadc27ea22f904502c8/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:4412396548808ddfed3f17a467b104ba7751e6b58678a4b840675c56d21cf7ed", size = 9533318, upload-time = "2025-03-07T01:40:10.421Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621, upload-time = "2025-03-07T01:40:21.213Z" },
|
||||
]
|
||||
|
||||
@@ -878,6 +884,7 @@ version = "12.8.93"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029, upload-time = "2025-03-07T01:42:13.562Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/eb/d1/e50d0acaab360482034b84b6e27ee83c6738f7d32182b987f9c7a4e32962/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fc1fec1e1637854b4c0a65fb9a8346b51dd9ee69e61ebaccc82058441f15bce8", size = 43106076, upload-time = "2025-03-07T01:41:59.817Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -885,6 +892,7 @@ name = "nvidia-cuda-runtime-cu12"
|
||||
version = "12.8.90"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7c/75/f865a3b236e4647605ea34cc450900854ba123834a5f1598e160b9530c3a/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:52bf7bbee900262ffefe5e9d5a2a69a30d97e2bc5bb6cc866688caa976966e3d", size = 965265, upload-time = "2025-03-07T01:39:43.533Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765, upload-time = "2025-03-07T01:40:01.615Z" },
|
||||
]
|
||||
|
||||
@@ -896,6 +904,7 @@ dependencies = [
|
||||
{ name = "nvidia-cublas-cu12" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/fa/41/e79269ce215c857c935fd86bcfe91a451a584dfc27f1e068f568b9ad1ab7/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8", size = 705026878, upload-time = "2025-06-06T21:52:51.348Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" },
|
||||
]
|
||||
|
||||
@@ -907,6 +916,7 @@ dependencies = [
|
||||
{ name = "nvidia-nvjitlink-cu12" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = "2025-03-07T01:44:56.873Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" },
|
||||
]
|
||||
|
||||
@@ -916,6 +926,7 @@ version = "1.13.1.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834, upload-time = "2025-03-07T01:45:50.723Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/f5/5607710447a6fe9fd9b3283956fceeee8a06cda1d2f56ce31371f595db2a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:4beb6d4cce47c1a0f1013d72e02b0994730359e17801d395bdcbf20cfb3bb00a", size = 1120705, upload-time = "2025-03-07T01:45:41.434Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -923,6 +934,7 @@ name = "nvidia-curand-cu12"
|
||||
version = "10.3.9.90"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/45/5e/92aa15eca622a388b80fbf8375d4760738df6285b1e92c43d37390a33a9a/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:dfab99248034673b779bc6decafdc3404a8a6f502462201f2f31f11354204acd", size = 63625754, upload-time = "2025-03-07T01:46:10.735Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976, upload-time = "2025-03-07T01:46:23.323Z" },
|
||||
]
|
||||
|
||||
@@ -936,6 +948,7 @@ dependencies = [
|
||||
{ name = "nvidia-nvjitlink-cu12" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = "2025-03-07T01:46:54.356Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" },
|
||||
]
|
||||
|
||||
@@ -947,6 +960,7 @@ dependencies = [
|
||||
{ name = "nvidia-nvjitlink-cu12" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129, upload-time = "2025-03-07T01:47:40.407Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" },
|
||||
]
|
||||
|
||||
@@ -955,6 +969,7 @@ name = "nvidia-cusparselt-cu12"
|
||||
version = "0.7.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/73/b9/598f6ff36faaece4b3c50d26f50e38661499ff34346f00e057760b35cc9d/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8878dce784d0fac90131b6817b607e803c36e629ba34dc5b433471382196b6a5", size = 283835557, upload-time = "2025-02-26T00:16:54.265Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691, upload-time = "2025-02-26T00:15:44.104Z" },
|
||||
]
|
||||
|
||||
@@ -963,6 +978,7 @@ name = "nvidia-nccl-cu12"
|
||||
version = "2.27.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/bb/1c/857979db0ef194ca5e21478a0612bcdbbe59458d7694361882279947b349/nvidia_nccl_cu12-2.27.5-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:31432ad4d1fb1004eb0c56203dc9bc2178a1ba69d1d9e02d64a6938ab5e40e7a", size = 322400625, upload-time = "2025-06-26T04:11:04.496Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6e/89/f7a07dc961b60645dbbf42e80f2bc85ade7feb9a491b11a1e973aa00071f/nvidia_nccl_cu12-2.27.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ad730cf15cb5d25fe849c6e6ca9eb5b76db16a80f13f425ac68d8e2e55624457", size = 322348229, upload-time = "2025-06-26T04:11:28.385Z" },
|
||||
]
|
||||
|
||||
@@ -972,6 +988,7 @@ version = "12.8.93"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836, upload-time = "2025-03-07T01:49:55.661Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/a2/8cee5da30d13430e87bf99bb33455d2724d0a4a9cb5d7926d80ccb96d008/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:adccd7161ace7261e01bb91e44e88da350895c270d23f744f0820c818b7229e7", size = 38386204, upload-time = "2025-03-07T01:49:43.612Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -979,6 +996,7 @@ name = "nvidia-nvshmem-cu12"
|
||||
version = "3.4.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/1d/6a/03aa43cc9bd3ad91553a88b5f6fb25ed6a3752ae86ce2180221962bc2aa5/nvidia_nvshmem_cu12-3.4.5-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0b48363fc6964dede448029434c6abed6c5e37f823cb43c3bcde7ecfc0457e15", size = 138936938, upload-time = "2025-09-06T00:32:05.589Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b5/09/6ea3ea725f82e1e76684f0708bbedd871fc96da89945adeba65c3835a64c/nvidia_nvshmem_cu12-3.4.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:042f2500f24c021db8a06c5eec2539027d57460e1c1a762055a6554f72c369bd", size = 139103095, upload-time = "2025-09-06T00:32:31.266Z" },
|
||||
]
|
||||
|
||||
@@ -987,6 +1005,7 @@ name = "nvidia-nvtx-cu12"
|
||||
version = "12.8.90"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/10/c0/1b303feea90d296f6176f32a2a70b5ef230f9bdeb3a72bddb0dc922dc137/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d7ad891da111ebafbf7e015d34879f7112832fc239ff0d7d776b6cb685274615", size = 91161, upload-time = "2025-03-07T01:42:23.922Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" },
|
||||
]
|
||||
|
||||
@@ -1393,14 +1412,14 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "sphinx-autodoc-typehints"
|
||||
version = "3.9.5"
|
||||
version = "3.9.7"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "sphinx" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/58/ec/21bd9babcfeb9930a73011257002d5cfa5fd30667b8de6d76dbaf8275dfb/sphinx_autodoc_typehints-3.9.5.tar.gz", hash = "sha256:60e646efb7c352a0e98f34dd7fdcde4527fbbdbdf30371ff8321b6b3eb1fd37d", size = 63249, upload-time = "2026-03-02T19:58:07.974Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f4/06/da2d9e98b3f7f0df144496e62f453e0025f129bccc7a6076b8ceae6047b1/sphinx_autodoc_typehints-3.9.7.tar.gz", hash = "sha256:70f3dd4e4dd815ae30e5d3848a26dca71fb5e7fcf8f37cf8b840dc8afdf07e82", size = 68689, upload-time = "2026-03-05T18:33:40.829Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7f/cb/80c250f47a0ca5ac67d82f14811b4068a551a12b4790b085ffdb900de427/sphinx_autodoc_typehints-3.9.5-py3-none-any.whl", hash = "sha256:c94f88a90b6c61a7a6686cb77b410e46e077712838387e6cf22d69e85cfd06a5", size = 34763, upload-time = "2026-03-02T19:58:06.028Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a4/a0/e7d3365dabfa79a1b2ac7d3122b5b22b401a9c4d5e4eadc5e13b88c63a2c/sphinx_autodoc_typehints-3.9.7-py3-none-any.whl", hash = "sha256:dd73f6a32adef0d8208f6f7d99254e1880259c77db7b4a91648345d45202d48e", size = 36691, upload-time = "2026-03-05T18:33:38.983Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1542,52 +1561,47 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "torch"
|
||||
version = "2.10.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
version = "2.10.0+cu128"
|
||||
source = { registry = "https://download.pytorch.org/whl/cu128" }
|
||||
dependencies = [
|
||||
{ name = "cuda-bindings", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "cuda-bindings", marker = "sys_platform == 'linux'" },
|
||||
{ name = "filelock" },
|
||||
{ name = "fsspec" },
|
||||
{ name = "jinja2" },
|
||||
{ name = "networkx" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-nvshmem-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cuda-cupti-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cuda-nvrtc-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cuda-runtime-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cudnn-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cufft-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cufile-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-curand-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cusolver-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-cusparselt-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-nccl-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-nvshmem-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "nvidia-nvtx-cu12", marker = "sys_platform == 'linux'" },
|
||||
{ name = "setuptools" },
|
||||
{ name = "sympy" },
|
||||
{ name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
|
||||
{ name = "triton", marker = "sys_platform == 'linux'" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ec/23/2c9fe0c9c27f7f6cb865abcea8a4568f29f00acaeadfc6a37f6801f84cb4/torch-2.10.0-2-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:e521c9f030a3774ed770a9c011751fb47c4d12029a3d6522116e48431f2ff89e", size = 79498254, upload-time = "2026-02-10T21:44:44.095Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c9/6f/f2e91e34e3fcba2e3fc8d8f74e7d6c22e74e480bbd1db7bc8900fdf3e95c/torch-2.10.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:5c4d217b14741e40776dd7074d9006fd28b8a97ef5654db959d8635b2fe5f29b", size = 146004247, upload-time = "2026-01-21T16:24:29.335Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/98/fb/5160261aeb5e1ee12ee95fe599d0541f7c976c3701d607d8fc29e623229f/torch-2.10.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:6b71486353fce0f9714ca0c9ef1c850a2ae766b409808acd58e9678a3edb7738", size = 915716445, upload-time = "2026-01-21T16:22:45.353Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6a/16/502fb1b41e6d868e8deb5b0e3ae926bbb36dab8ceb0d1b769b266ad7b0c3/torch-2.10.0-cp313-cp313-win_amd64.whl", hash = "sha256:c2ee399c644dc92ef7bc0d4f7e74b5360c37cdbe7c5ba11318dda49ffac2bc57", size = 113757050, upload-time = "2026-01-21T16:24:19.204Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1a/0b/39929b148f4824bc3ad6f9f72a29d4ad865bcf7ebfc2fa67584773e083d2/torch-2.10.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:3202429f58309b9fa96a614885eace4b7995729f44beb54d3e4a47773649d382", size = 79851305, upload-time = "2026-01-21T16:24:09.209Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d8/14/21fbce63bc452381ba5f74a2c0a959fdf5ad5803ccc0c654e752e0dbe91a/torch-2.10.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:aae1b29cd68e50a9397f5ee897b9c24742e9e306f88a807a27d617f07adb3bd8", size = 146005472, upload-time = "2026-01-21T16:22:29.022Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/54/fd/b207d1c525cb570ef47f3e9f836b154685011fce11a2f444ba8a4084d042/torch-2.10.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:6021db85958db2f07ec94e1bc77212721ba4920c12a18dc552d2ae36a3eb163f", size = 915612644, upload-time = "2026-01-21T16:21:47.019Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/36/53/0197f868c75f1050b199fe58f9bf3bf3aecac9b4e85cc9c964383d745403/torch-2.10.0-cp313-cp313t-win_amd64.whl", hash = "sha256:ff43db38af76fda183156153983c9a096fc4c78d0cd1e07b14a2314c7f01c2c8", size = 113997015, upload-time = "2026-01-21T16:23:00.767Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0e/13/e76b4d9c160e89fff48bf16b449ea324bda84745d2ab30294c37c2434c0d/torch-2.10.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:cdf2a523d699b70d613243211ecaac14fe9c5df8a0b0a9c02add60fb2a413e0f", size = 79498248, upload-time = "2026-01-21T16:23:09.315Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4f/93/716b5ac0155f1be70ed81bacc21269c3ece8dba0c249b9994094110bfc51/torch-2.10.0-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:bf0d9ff448b0218e0433aeb198805192346c4fd659c852370d5cc245f602a06a", size = 79464992, upload-time = "2026-01-21T16:23:05.162Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/69/2b/51e663ff190c9d16d4a8271203b71bc73a16aa7619b9f271a69b9d4a936b/torch-2.10.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:233aed0659a2503b831d8a67e9da66a62c996204c0bba4f4c442ccc0c68a3f60", size = 146018567, upload-time = "2026-01-21T16:22:23.393Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5e/cd/4b95ef7f293b927c283db0b136c42be91c8ec6845c44de0238c8c23bdc80/torch-2.10.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:682497e16bdfa6efeec8cde66531bc8d1fbbbb4d8788ec6173c089ed3cc2bfe5", size = 915721646, upload-time = "2026-01-21T16:21:16.983Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/56/97/078a007208f8056d88ae43198833469e61a0a355abc0b070edd2c085eb9a/torch-2.10.0-cp314-cp314-win_amd64.whl", hash = "sha256:6528f13d2a8593a1a412ea07a99812495bec07e9224c28b2a25c0a30c7da025c", size = 113752373, upload-time = "2026-01-21T16:22:13.471Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d8/94/71994e7d0d5238393df9732fdab607e37e2b56d26a746cb59fdb415f8966/torch-2.10.0-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:f5ab4ba32383061be0fb74bda772d470140a12c1c3b58a0cfbf3dae94d164c28", size = 79850324, upload-time = "2026-01-21T16:22:09.494Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e2/65/1a05346b418ea8ccd10360eef4b3e0ce688fba544e76edec26913a8d0ee0/torch-2.10.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:716b01a176c2a5659c98f6b01bf868244abdd896526f1c692712ab36dbaf9b63", size = 146006482, upload-time = "2026-01-21T16:22:18.42Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1d/b9/5f6f9d9e859fc3235f60578fa64f52c9c6e9b4327f0fe0defb6de5c0de31/torch-2.10.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:d8f5912ba938233f86361e891789595ff35ca4b4e2ac8fe3670895e5976731d6", size = 915613050, upload-time = "2026-01-21T16:20:49.035Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/66/4d/35352043ee0eaffdeff154fad67cd4a31dbed7ff8e3be1cc4549717d6d51/torch-2.10.0-cp314-cp314t-win_amd64.whl", hash = "sha256:71283a373f0ee2c89e0f0d5f446039bdabe8dbc3c9ccf35f0f784908b0acd185", size = 113995816, upload-time = "2026-01-21T16:22:05.312Z" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:bdbcc703382f948e951c063448c9406bf38ce66c41dd698d9e2733fcf96c037a" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:7b4bd23ed63de97456fcc81c26fea9f02ee02ce1112111c4dac0d8cfe574b23e" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313-win_amd64.whl", hash = "sha256:4d1b0b49c54223c7c04050b49eac141d77b6edbc34aea1dfc74a6fdb661baa8c" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:f1f8b840c64b645a4bc61a393db48effb9c92b2dc26c8373873911f0750d1ea7" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:23f58258012bcf1c349cb22af387e33aadca7f83ea617b080e774eb41e4fe8ff" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313t-win_amd64.whl", hash = "sha256:01b216e097b17a5277cfb47c383cdcacf06abeadcb0daca0c76b59e72854c3b6" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:c42377bc2607e3e1c60da71b792fb507c3938c87fd6edab8b21c59c91473c36d" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:37d71feea068776855686a1512058df3f19f6f040a151f055aa746601678744f" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314-win_amd64.whl", hash = "sha256:c57017ca29e62271e362fdeee7d20070e254755a5148b30b553d8a10fc83c7ef" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:777461f50b2daf77e4bdd8e2ad34bdfc5a993bf1bdf2ab9ef39f5edfe4e9c12b" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:7bcba6a7c5f0987a13298b1ca843155dcceceac758fa3c7ccd5c7af4059a1080" },
|
||||
{ url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314t-win_amd64.whl", hash = "sha256:70d89143c956389d4806cb4e5fe0b1129fe0db280e1073288d17fa76c101cba4" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1623,7 +1637,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "trainlib"
|
||||
version = "0.1.0"
|
||||
version = "0.1.2"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "colorama" },
|
||||
@@ -1662,7 +1676,7 @@ requires-dist = [
|
||||
{ name = "sphinx-autodoc-typehints", marker = "extra == 'doc'" },
|
||||
{ name = "sphinx-togglebutton", marker = "extra == 'doc'" },
|
||||
{ name = "tensorboard", specifier = ">=2.20.0" },
|
||||
{ name = "torch", specifier = ">=2.5.1" },
|
||||
{ name = "torch", index = "https://download.pytorch.org/whl/cu128" },
|
||||
{ name = "tqdm", specifier = ">=4.67.1" },
|
||||
]
|
||||
provides-extras = ["dev", "doc", "test"]
|
||||
@@ -1681,9 +1695,13 @@ name = "triton"
|
||||
version = "3.6.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/3c/12/34d71b350e89a204c2c7777a9bba0dcf2f19a5bfdd70b57c4dbc5ffd7154/triton-3.6.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:448e02fe6dc898e9e5aa89cf0ee5c371e99df5aa5e8ad976a80b93334f3494fd", size = 176133521, upload-time = "2026-01-20T16:16:13.321Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f9/0b/37d991d8c130ce81a8728ae3c25b6e60935838e9be1b58791f5997b24a54/triton-3.6.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10c7f76c6e72d2ef08df639e3d0d30729112f47a56b0c81672edc05ee5116ac9", size = 188289450, upload-time = "2026-01-20T16:00:49.136Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ce/4e/41b0c8033b503fd3cfcd12392cdd256945026a91ff02452bef40ec34bee7/triton-3.6.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1722e172d34e32abc3eb7711d0025bb69d7959ebea84e3b7f7a341cd7ed694d6", size = 176276087, upload-time = "2026-01-20T16:16:18.989Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/35/f8/9c66bfc55361ec6d0e4040a0337fb5924ceb23de4648b8a81ae9d33b2b38/triton-3.6.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d002e07d7180fd65e622134fbd980c9a3d4211fb85224b56a0a0efbd422ab72f", size = 188400296, upload-time = "2026-01-20T16:00:56.042Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/49/55/5ecf0dcaa0f2fbbd4420f7ef227ee3cb172e91e5fede9d0ecaddc43363b4/triton-3.6.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ef5523241e7d1abca00f1d240949eebdd7c673b005edbbce0aca95b8191f1d43", size = 176138577, upload-time = "2026-01-20T16:16:25.426Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/df/3d/9e7eee57b37c80cec63322c0231bb6da3cfe535a91d7a4d64896fcb89357/triton-3.6.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a17a5d5985f0ac494ed8a8e54568f092f7057ef60e1b0fa09d3fd1512064e803", size = 188273063, upload-time = "2026-01-20T16:01:07.278Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/48/db/56ee649cab5eaff4757541325aca81f52d02d4a7cd3506776cad2451e060/triton-3.6.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0b3a97e8ed304dfa9bd23bb41ca04cdf6b2e617d5e782a8653d616037a5d537d", size = 176274804, upload-time = "2026-01-20T16:16:31.528Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f6/56/6113c23ff46c00aae423333eb58b3e60bdfe9179d542781955a5e1514cb3/triton-3.6.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:46bd1c1af4b6704e554cad2eeb3b0a6513a980d470ccfa63189737340c7746a7", size = 188397994, upload-time = "2026-01-20T16:01:14.236Z" },
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user