add plot styles, clean up package-wide docstrings

This commit is contained in:
2026-03-07 03:10:13 -08:00
parent faeef9c72a
commit e867bc0e7f
11 changed files with 398 additions and 119 deletions

View File

@@ -6,8 +6,8 @@
# -- Project information ------------------------------------------------------ # -- Project information ------------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
project = "<package-name>" project = "trainlib"
copyright = "2025, Sam Griesemer" copyright = "2026, Sam Griesemer"
author = "Sam Griesemer" author = "Sam Griesemer"
# -- General configuration ---------------------------------------------------- # -- General configuration ----------------------------------------------------
@@ -44,6 +44,7 @@ exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
html_theme = "furo" html_theme = "furo"
html_static_path = ["_static"] html_static_path = ["_static"]
# html_sidebars = { # html_sidebars = {
# '**': ['/modules.html'], # '**': ['/modules.html'],
# } # }

View File

@@ -1,4 +1,4 @@
# `<project-name>` package docs # `trainlib` package docs
{ref}`genindex` {ref}`genindex`
{ref}`modindex` {ref}`modindex`
{ref}`search` {ref}`search`
@@ -14,7 +14,7 @@
:maxdepth: 3 :maxdepth: 3
:caption: Autoref :caption: Autoref
_autoref/<project-name>.rst _autoref/index.rst
``` ```
```{toctree} ```{toctree}

View File

@@ -24,11 +24,11 @@ classifiers = [
"Intended Audience :: End Users/Desktop", "Intended Audience :: End Users/Desktop",
] ]
dependencies = [ dependencies = [
"torch",
"colorama>=0.4.6", "colorama>=0.4.6",
"matplotlib>=3.10.8", "matplotlib>=3.10.8",
"numpy>=2.4.1", "numpy>=2.4.1",
"tensorboard>=2.20.0", "tensorboard>=2.20.0",
"torch>=2.5.1",
"tqdm>=4.67.1", "tqdm>=4.67.1",
] ]
@@ -82,3 +82,11 @@ force-sort-within-sections = false
quote-style = "double" quote-style = "double"
indent-style = "space" indent-style = "space"
docstring-code-format = true docstring-code-format = true
[tool.uv.sources]
torch = { index = "pytorch" }
[[tool.uv.index]]
name = "pytorch"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

View File

@@ -7,12 +7,12 @@
could have a single dataset definition for a particular concrete dataset, could have a single dataset definition for a particular concrete dataset,
and so long as we're talking about the same items, it can be instantiated and so long as we're talking about the same items, it can be instantiated
using *any domain*. You wouldn't need specific subclasses for disk or using *any domain*. You wouldn't need specific subclasses for disk or
network or in-memory; you can tell it directly at runtime. network or in-memory structures; you can tell it directly at runtime.
That's an eventual possibility, anyway. As it stands, however, this is That's an eventual possibility, anyway. As it stands, however, this is
effectively impossible: effectively impossible:
You can't easily abstract the batch -> item splitting process, i.e., You can't easily abstract the batch-to-item splitting process, i.e.,
``_process_batch_data()``. A list-based version of the dataset you're ``_process_batch_data()``. A list-based version of the dataset you're
trying to define might have an individual item tuple at every index, trying to define might have an individual item tuple at every index,
whereas a disk-based version might have tuples batched across a few files. whereas a disk-based version might have tuples batched across a few files.
@@ -41,22 +41,22 @@
accepting a batch processing method upon instantiation rather than accepting a batch processing method upon instantiation rather than
structurally bolting it into the ``Dataset`` definition. This would require structurally bolting it into the ``Dataset`` definition. This would require
knowledge of the item structure ``I`` as well as the ``Domain[U, R]``, so knowledge of the item structure ``I`` as well as the ``Domain[U, R]``, so
such a function will always have to be (I, U, R)-dependent. It nevertheless such a function will always have to be ``(I, U, R)``-dependent. It
would take out some of the pain of having to define new dataset classes; nevertheless would take out some of the pain of having to define new
instead, you'd just need to define the batch processing method. I see this dataset classes; instead, you'd just need to define the batch processing
as a worse alternative to just defining *inside* a safe context like a new method. I see this as a worse alternative to just defining *inside* a safe
dataset class: you know the types you have to respect, and you stick that context like a new dataset class: you know the types you have to respect,
method exactly in a context where it's understood. Freeing this up doesn't and you stick that method exactly in a context where it's understood.
lighten the burden of processing logic, it just changes *when* it has to be Freeing this up doesn't lighten the burden of processing logic, it just
provided, and that's not worth much (to me) in this case given the bump in changes *when* it has to be provided, and that's not worth much (to me) in
complexity. (Taking this to the extreme: you could supply *all* of an this case given the bump in complexity. (Taking this to the extreme: you
object's methods "dynamically" and glue them together at runtime so long as could supply *all* of an object's methods "dynamically" and glue them
they all played nice. But wherever you were "laying them out" beforehand is together at runtime so long as they all played nice. But wherever you were
exactly the job of a class to begin with, so you don't end up with anything "laying them out" beforehand is exactly the job of a class to begin with,
more dynamic. All we're really discussing here is pushing around so you don't end up with anything more dynamic. All we're really discussing
unavoidable complexity inside and outside of the "class walls," and in the here is pushing around unavoidable complexity inside and outside of the
particular case of ``_process_batch_data()``, it feels much better when "class walls," and in the particular case of ``_process_batch_data()``, it
it's on the inside.) feels much better when it's on the inside.)
.. admonition:: Holding area .. admonition:: Holding area
@@ -114,6 +114,7 @@
T = TypeVar("T", bound=NamedTuple) T = TypeVar("T", bound=NamedTuple)
class NamedTupleDataset[I](Dataset): class NamedTupleDataset[I](Dataset):
def __init__(self, data_list: list[I]) -> None: def __init__(self, data_list: list[I]) -> None:
self.data_list = data_list self.data_list = data_list
@@ -161,43 +162,76 @@ class BatchedDataset[U, R, I](Dataset):
The class is generic over a URI type ``U``, a resource type ``R`` (both of The class is generic over a URI type ``U``, a resource type ``R`` (both of
which are used to concretize a domain ``Domain[U, R]``), and an item type which are used to concretize a domain ``Domain[U, R]``), and an item type
``T`` (which has a ``tuple`` upper bound). ``I``.
.. admonition:: Batch and item processing flow
.. admonition:: Pipeline overview .. code-block:: text
.. code-block:: python Domain -> [U] :: self._batch_uris = list(domain)
Domain -> [U] (get _batch_uris) Grab all URIs from Domain iterators. This is made concrete
U -> R (domain access ; Rs provide batches) early to allow for Dataset sizing, and we need a Sequence
R -> [I] (cache here ; _process_batch_data to use load_transform) representation to map integer batch indices into Domains, i.e.,
[I] -> I (human item obj ; _get_item) when getting the corresponding URI:
I -> **P (final packed item ; __getitem__ to use transform)
batch_uri = self._batch_uris[batch_index]
We let Domains implement iterators over their URIs, but
explicitly exhaust when initializing Datasets.
U -> R :: batch_data = self.domain[batch_uri]
Retrieve resource from domain. Resources are viewed as batched
data, even if only wrapping single items (happens in trivial
settings).
R -> [I] :: self._process_batch_data(batch_data, batch_index)
Possibly domain-specific batch processing of resource data into
explicit Sequence-like structures of items, each of which is
subject to the provided pre_transform. Processed batches at
this stage are cached (if enabled).
[I] -> I :: self.get_batch(batch_index)[index_in_batch]
Select individual items from batches in _get_item. At this
stage, items are in intermediate states and pulled from the
cached batches.
I -> I :: self._process_item_data(item_data, index)
Produce final items with __getitem__, getting intermediate
items via _get_item and applying the provided post_transform.
.. note::
Note^1: as far as positioning, this class is meant to play nice with Note^1: as far as positioning, this class is meant to play nice with
PyTorch DataLoaders, hence the inheritance from ``torch.Dataset``. The PyTorch DataLoaders, hence the inheritance from ``torch.Dataset``. The
value add for this over the ``torch.Dataset`` base is almost entirely value add for this over the ``torch.Dataset`` base is almost entirely
in the logic it implements to map out of *batched resources* that are in the logic it implements to map out of *batched resources* that are
holding data, and flattening it out into typical dataset items. There holding data, and flattening it out into typical dataset items. There
are also some QoL items when it comes to splitting and balancing are also some QoL features when it comes to splitting and balancing
samples. samples.
Note^2: even though ``Domains`` implement iterators over their URIs, .. note::
this doesn't imply a ``BatchedDataset`` is iterable. This just means we Even though ``Domains`` implement iterators over their URIs, this
can walk over the resources that provide data, but we don't necessarily doesn't imply a ``BatchedDataset`` is iterable. This just means we
presuppose an ordered walk over samples within batches. Point being: can walk over the resources that provide data, but we don't
``torch.Dataset``, not ``torch.IterableDataset``, is the appropriate necessarily presuppose an ordered walk over samples within batches.
superclass, even when we're working around iterable ``Domains``. Point being: ``torch.Dataset``, not ``torch.IterableDataset``, is
the appropriate superclass, even when we're working around iterable
``Domains``.
Note^3: transforms are expected to operate on ``I``-items and produce .. note::
``I``-items. They shouldn't be the "introducers" of ``I`` types from Transforms are expected to operate on ``I``-items and produce
some other intermediate representation, nor should they map from ``I`` ``I``-items. They shouldn't be the "introducers" of ``I`` types
to something else. Point being: the dataset definition should be able from some other intermediate representation, nor should they map
to map resources ``R`` to ``I`` without a transform: that much should from ``I`` to something else. Point being: the dataset definition
be baked into the class definition. If you find you're expecting the should be able to map resources ``R`` to ``I`` without a transform:
transform to do that for you, you should consider pulling in some that much should be baked into the class definition. If you find
common structure across the allowed transforms and make it a fixed part you're expecting the transform to do that for you, you should
of the class. consider pulling in some common structure across the allowed
transforms and make it a fixed part of the class.
""" """
def __init__( def __init__(
@@ -211,6 +245,7 @@ class BatchedDataset[U, R, I](Dataset):
) -> None: ) -> None:
""" """
Parameters: Parameters:
domain: ``Domain`` object providing access to batched data
pre_transform: transform to apply over items during loading (in pre_transform: transform to apply over items during loading (in
``_process_batch_data()``), i.e., *before* going into ``_process_batch_data()``), i.e., *before* going into
persistent storage persistent storage
@@ -220,6 +255,7 @@ class BatchedDataset[U, R, I](Dataset):
batch_cache_limit: the max number of batches to cache at any batch_cache_limit: the max number of batches to cache at any
one time one time
preload: whether to load all data into memory during instantiation preload: whether to load all data into memory during instantiation
num_workers: number of workers to use when preloading data
""" """
self.domain = domain self.domain = domain
@@ -259,6 +295,9 @@ class BatchedDataset[U, R, I](Dataset):
The behavior of this method can vary depending on what we know about The behavior of this method can vary depending on what we know about
batch sizes, and should therefore be implemented by inheriting classes. batch sizes, and should therefore be implemented by inheriting classes.
Parameters:
item_index: index of item
Returns: Returns:
batch_index: int batch_index: int
index_in_batch: int index_in_batch: int
@@ -302,6 +341,10 @@ class BatchedDataset[U, R, I](Dataset):
place to use a provided ``post_transform``; items are pulled from the place to use a provided ``post_transform``; items are pulled from the
cache (if enabled) and processed before being returned as the final cache (if enabled) and processed before being returned as the final
tuple outputs (so this processing is not persistent). tuple outputs (so this processing is not persistent).
Parameters:
item_data: item data
item_index: index of item
""" """
raise NotImplementedError raise NotImplementedError
@@ -317,6 +360,9 @@ class BatchedDataset[U, R, I](Dataset):
Note that return values from `__getitem__()` are "cleaned up" versions Note that return values from `__getitem__()` are "cleaned up" versions
of this representation, with minimal info needed for training. of this representation, with minimal info needed for training.
Parameters:
item_index: index of item
""" """
if item_index >= len(self): if item_index >= len(self):
@@ -355,6 +401,9 @@ class BatchedDataset[U, R, I](Dataset):
any delayed reads here. There's no way around needing to see all batch any delayed reads here. There's no way around needing to see all batch
data at once here, and we don't want to make that ambiguous: ``list`` data at once here, and we don't want to make that ambiguous: ``list``
output type it is. output type it is.
Parameters:
batch_index: index of batch
""" """
logger.debug("Batch cache miss, reading from root...") logger.debug("Batch cache miss, reading from root...")
@@ -374,6 +423,9 @@ class BatchedDataset[U, R, I](Dataset):
Can be useful when dynamically pulling data (as it's requested) isn't Can be useful when dynamically pulling data (as it's requested) isn't
desired. Requires that `batch_cache_limit=None`, i.e., the cache won't desired. Requires that `batch_cache_limit=None`, i.e., the cache won't
continually remove previous batches as they're loaded. continually remove previous batches as they're loaded.
Parameters:
num_workers: number of parallel workers to use for data loading
""" """
assert self.batch_cache_limit is None, "Preloading under cache limit" assert self.batch_cache_limit is None, "Preloading under cache limit"
@@ -431,11 +483,17 @@ class BatchedDataset[U, R, I](Dataset):
memory, but should otherwise leave most interactions unchanged. memory, but should otherwise leave most interactions unchanged.
Parameters: Parameters:
frac: split fractions for datasets
dataset: dataset to split, defaults to ``self``. Facilitates
recursive splitting when multi-attribute splits are needed.
by_attr: attribute or attributes to use when grouping strata for
dataset splits. Defaults to ``None``, which will use item
indices.
shuffle_strata: shuffle the strata order before split is drawn. We shuffle_strata: shuffle the strata order before split is drawn. We
parameterize this because a dataloader-level shuffle operation parameterize this because a Dataloader-level shuffle operation
will only change the order of the indices in the resulting will only change the order of the indices in the resulting
splits; only a shuffle of items inside the strata can change splits; only a shuffle of the strata order can change the
the actual content of the splits themselves. actual content of the splits themselves.
""" """
if by_attr == []: if by_attr == []:
@@ -544,6 +602,32 @@ class BatchedDataset[U, R, I](Dataset):
split_max_sizes: list[int] | None = None, split_max_sizes: list[int] | None = None,
shuffle_strata: bool = True, shuffle_strata: bool = True,
) -> None: ) -> None:
"""
Balance the distribution of provided attributes over dataset items.
This method sets the indices over the dataset according to the result
of the rebalancing. The indices are produced by the recursive
``_balance()`` method, which is necessarily separate due to the need
for a contained recursive approach that doesn't change the underlying
dataset during execution.
Parameters:
dataset: dataset to split, defaults to ``self``. Facilitates
recursive splitting when multi-attribute splits are needed.
by_attr: attribute or attributes to use when grouping strata for
dataset splits. Defaults to ``None``, which will use item
indices.
split_min_sizes: minimum allowed sizes of splits. Must have the
same length as ``by_attr``.
split_max_sizes: maximum allowed sizes of splits. Must have the
same length as ``by_attr``.
shuffle_strata: shuffle the strata order before split is drawn. We
parameterize this because a Dataloader-level shuffle operation
will only change the order of the indices in the resulting
splits; only a shuffle of the strata order can change the
actual content of the splits themselves.
"""
self.indices = self._balance( self.indices = self._balance(
dataset, dataset,
by_attr, by_attr,
@@ -561,9 +645,27 @@ class BatchedDataset[U, R, I](Dataset):
shuffle_strata: bool = True, shuffle_strata: bool = True,
) -> list[int]: ) -> list[int]:
""" """
Note: behavior is a little odd for nested behavior; not exactly .. note::
perfectly uniform throughout. This is a little difficult: you can't
exactly know ahead of time the size of the subgroups across splits Behavior is a little odd for nested behavior; not exactly perfectly
uniform throughout. This is a little difficult: you can't exactly
know ahead of time the size of the subgroups across splits
Parameters:
dataset: dataset to split, defaults to ``self``. Facilitates
recursive splitting when multi-attribute splits are needed.
by_attr: attribute or attributes to use when grouping strata for
dataset splits. Defaults to ``None``, which will use item
indices.
split_min_sizes: minimum allowed sizes of splits. Must have the
same length as ``by_attr``.
split_max_sizes: maximum allowed sizes of splits. Must have the
same length as ``by_attr``.
shuffle_strata: shuffle the strata order before split is drawn. We
parameterize this because a Dataloader-level shuffle operation
will only change the order of the indices in the resulting
splits; only a shuffle of the strata order can change the
actual content of the splits themselves.
""" """
if by_attr == []: if by_attr == []:
@@ -653,6 +755,9 @@ class BatchedDataset[U, R, I](Dataset):
dataset. The underlying data remain the same, but when indices get set, dataset. The underlying data remain the same, but when indices get set,
you're effectively applying a mask over any existing indices, always you're effectively applying a mask over any existing indices, always
operating *relative* to the existing mask. operating *relative* to the existing mask.
Parameters:
indices: list of indices to set
""" """
# manually set new size # manually set new size
@@ -680,6 +785,13 @@ class BatchedDataset[U, R, I](Dataset):
return self._dataset_len return self._dataset_len
def __getitem__(self, index: int) -> I: def __getitem__(self, index: int) -> I:
"""
Get the dataset item at the specified index.
Parameters:
index: index of item to retrieve
"""
item_data = self._get_item(index) item_data = self._get_item(index)
index = self.indices[index] index = self.indices[index]
@@ -888,7 +1000,7 @@ class HomogenousDataset[U, R, I](BatchedDataset[U, R, I]):
class HeterogenousDataset[U, R, I](BatchedDataset[U, R, I]): class HeterogenousDataset[U, R, I](BatchedDataset[U, R, I]):
""" """
Batched dataset where batches have arbitrary size. Batched dataset where batches may have arbitrary size.
Methods left for inheriting classes: Methods left for inheriting classes:

View File

@@ -107,8 +107,14 @@ class LSTM[K: RNNKwargs](Estimator[K]):
with torch.no_grad(): with torch.no_grad():
loss = next(self.loss(**kwargs)).item() loss = next(self.loss(**kwargs)).item()
predictions = self(**kwargs)[0]
labels = kwargs["labels"]
mae = F.l1_loss(predictions, labels).item()
return { return {
"loss": loss, "loss": loss,
"mse": loss,
"mae": mae,
"grad_norm": get_grad_norm(self) "grad_norm": get_grad_norm(self)
} }
@@ -291,7 +297,7 @@ class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]):
logger.info(f"| > {self.output_dim=}") logger.info(f"| > {self.output_dim=}")
class ConvRNN[K: RNNKwargs](Estimator[K]): class ConvGRU[K: RNNKwargs](Estimator[K]):
""" """
Base recurrent convolutional architecture. Base recurrent convolutional architecture.
@@ -441,11 +447,18 @@ class ConvRNN[K: RNNKwargs](Estimator[K]):
with torch.no_grad(): with torch.no_grad():
loss = next(self.loss(**kwargs)).item() loss = next(self.loss(**kwargs)).item()
predictions = self(**kwargs)[0].squeeze(-1)
labels = kwargs["labels"]
mae = F.l1_loss(predictions, labels).item()
return { return {
"loss": loss, "loss": loss,
"mse": loss,
"mae": mae,
"grad_norm": get_grad_norm(self) "grad_norm": get_grad_norm(self)
} }
def optimizers( def optimizers(
self, self,
**kwargs: Unpack[OptimizerKwargs], **kwargs: Unpack[OptimizerKwargs],

View File

@@ -31,7 +31,15 @@ logger: logging.Logger = logging.getLogger(__name__)
class Trainer[I, K: EstimatorKwargs]: class Trainer[I, K: EstimatorKwargs]:
""" """
Training interface for updating ``Estimators`` with ``Datasets``. Training interface for optimizing parameters of ``Estimators`` with
``Datasets``.
This class is generic to a dataset item type ``I`` and an estimator kwarg
type ``K``. These are the two primary components ``Trainer`` objects need
to coordinate: they ultimately rely on a provided map to ensure data items
(type ``I``) from a dataset are appropriately routed as inputs to key
estimator methods (like ``forward()`` and ``loss()``), which accept inputs
of type ``K``.
""" """
def __init__( def __init__(
@@ -43,8 +51,10 @@ class Trainer[I, K: EstimatorKwargs]:
) -> None: ) -> None:
""" """
Parameters: Parameters:
estimator: `Estimator` model object estimator: ``Estimator`` model object
device: device on which to carry out training device: device on which to carry out training
chkpt_dir: directory to write model checkpoints
tblog_dir: directory to write TensorBoard logs
""" """
self.device: str self.device: str
@@ -87,7 +97,7 @@ class Trainer[I, K: EstimatorKwargs]:
def reset(self) -> None: def reset(self) -> None:
""" """
Set base tracking parameters. Set initial tracking parameters for the primary training loop.
""" """
self._step: int = 0 self._step: int = 0
@@ -276,13 +286,13 @@ class Trainer[I, K: EstimatorKwargs]:
one should take care to synchronize the sample structure with `dataset` one should take care to synchronize the sample structure with `dataset`
to match that expected by ``self.estimator.loss(...)``. to match that expected by ``self.estimator.loss(...)``.
.. admonition:: On batch_estimator_map .. admonition:: On ``batch_estimator_map``
Dataloader collate functions are responsible for mapping a Dataloader collate functions are responsible for mapping a
collection of items into an item of collections, roughly speaking. collection of items into an item of collections, roughly speaking.
If items are tuples of tensors, If items are tuples of tensors,
.. code-block:: .. code-block:: text
[ [
( [1, 1], [1, 1] ), ( [1, 1], [1, 1] ),
@@ -293,7 +303,7 @@ class Trainer[I, K: EstimatorKwargs]:
the collate function maps back into the item skeleton, producing a the collate function maps back into the item skeleton, producing a
single tuple of (stacked) tensors single tuple of (stacked) tensors
.. code-block:: .. code-block:: text
( [[1, 1], ( [[1, 1],
[2, 2], [2, 2],
@@ -309,8 +319,13 @@ class Trainer[I, K: EstimatorKwargs]:
``K``). ``K``).
Parameters: Parameters:
dataset: dataset to train the estimator
batch_estimator_map: function mapping from batch data to expected
estimator kwargs
lr: learning rate (default: 1e-3) lr: learning rate (default: 1e-3)
eps: adam EPS (default: 1e-8) eps: adam EPS (default: 1e-8)
max_grad_norm: upper bound to use when clipping gradients. If left
as ``None``, no gradient clipping is performed.
max_epochs: maximum number of training epochs max_epochs: maximum number of training epochs
stop_after_epochs: number of epochs with stagnant validation losses stop_after_epochs: number of epochs with stagnant validation losses
to allow before early stopping. If training stops earlier, the to allow before early stopping. If training stops earlier, the
@@ -395,11 +410,7 @@ class Trainer[I, K: EstimatorKwargs]:
# save checkpoint # save checkpoint
if self._epoch % chkpt_every == 0: if self._epoch % chkpt_every == 0:
self.save_model( self.save_model(self._epoch, self.chkpt_dir, dir_prefix)
self._epoch,
self.chkpt_dir,
dir_prefix
)
self._epoch += 1 self._epoch += 1
@@ -493,7 +504,7 @@ class Trainer[I, K: EstimatorKwargs]:
def _summarize(self, writer: SummaryWriter, epoch: int) -> None: def _summarize(self, writer: SummaryWriter, epoch: int) -> None:
""" """
Flush the training summary to the TB summary writer. Flush the training summary to the TensorBoard summary writer.
""" """
summary_values = defaultdict(list) summary_values = defaultdict(list)
@@ -547,17 +558,18 @@ class Trainer[I, K: EstimatorKwargs]:
chkpt_dir.mkdir(parents=True, exist_ok=True) chkpt_dir.mkdir(parents=True, exist_ok=True)
chkpt_path.write_bytes(model_buff.getvalue()) chkpt_path.write_bytes(model_buff.getvalue())
def load_model( def load_model(self, epoch: int, chkpt_dir: str) -> None:
self,
epoch: int,
chkpt_dir: str,
) -> None:
""" """
Load a model checkpoint from a given epoch. Load a model checkpoint from a given epoch.
Note that this assumes the model was saved via `Trainer.save_model()`, Note that this assumes the model was saved via
and the estimator provided to this `Trainer` instance matches the ``Trainer.save_model()``, and the estimator provided to this
architecture of the checkpoint model being loaded. ``Trainer`` instance matches the architecture of the checkpoint model
being loaded.
Parameters:
epoch: epoch of saved model
chkpt_dir: directory containing the saved model checkpoints
""" """
model_class = self.estimator.__class__.__name__ model_class = self.estimator.__class__.__name__

View File

@@ -8,4 +8,14 @@ class Transform[I]:
""" """
def __call__(self, item: I) -> I: def __call__(self, item: I) -> I:
"""
Apply transform to item.
Parameters:
item: item object to transform
Returns:
transformed item (same type ``I`` as input)
"""
raise NotImplementedError raise NotImplementedError

View File

@@ -0,0 +1,46 @@
# Matplotlib style sheet: plain-text rendering with DejaVu fonts, one font
# size for all text elements, a custom color cycle, and SVG save defaults.

# Text and fonts: LaTeX rendering is off; DejaVu families are used throughout
text.usetex : False
mathtext.default : regular
font.family : sans-serif
font.sans-serif : DejaVu Sans
font.serif : DejaVu Serif
font.cursive : DejaVu Sans
mathtext.fontset : dejavuserif

# Font sizes: the same value for base text, titles, legend, and tick labels
font.size : 9
figure.titlesize : 9
legend.fontsize : 9
axes.titlesize : 9
axes.labelsize : 9
xtick.labelsize : 9
ytick.labelsize : 9

# Color cycle; the commented-out line preserves an alternative palette
#axes.prop_cycle : cycler('color', ['4f7dd5', 'af7031', '55905e', 'd84739', '888348', 'b75e8b', '2f8f99', '9862cb'])
axes.prop_cycle : cycler('color', ['5e8de4', 'c38141', '67a771', 'e15344', '9e9858', '41a6b0', 'a46fd7', 'c86d9a'])

# Image handling
image.interpolation : nearest
image.resample : False
image.composite_image : True

# Axes spines: keep left and bottom, hide top and right
axes.spines.left : True
axes.spines.bottom : True
axes.spines.top : False
axes.spines.right : False

# Stroke widths and marker size: all set to 1
axes.linewidth : 1
xtick.major.width : 1
xtick.minor.width : 1
ytick.major.width : 1
ytick.minor.width : 1
lines.linewidth : 1
lines.markersize : 1

# Saved figures: SVG format with a tight bounding box
savefig.dpi : 300
savefig.format : svg
savefig.bbox : tight
savefig.pad_inches : 0.1
svg.image_inline : True
svg.fonttype : none

# Legend: no frame
legend.frameon : False

38
trainlib/utils/plot.py Normal file
View File

@@ -0,0 +1,38 @@
from pathlib import Path
import matplotlib as mpl
import matplotlib.pyplot as plt
FILE = Path(__file__).parent.absolute()
class use_style:
    """
    Context manager that temporarily applies Matplotlib styles.

    On entry, a copy of the current ``mpl.rcParams`` is stored and the
    configured styles are applied with ``plt.style.use``. On exit, the stored
    values are written back with ``mpl.rcParams.update``.

    Parameters:
        style: style specification(s) to apply. May be a list of
            specifications or a single one (``str``, ``Path`` or ``dict``),
            which is wrapped into a one-element list. Defaults to the
            ``custom.mplstyle`` sheet located next to this module.
        **kwargs: individual rcParam overrides, passed to ``plt.style.use``
            as the final entry of the style list
    """

    def __init__(
        self,
        style: list[str] | str | Path | None = None,
        **kwargs: str,
    ) -> None:
        if style is None:
            style = [str(Path(FILE, "custom.mplstyle"))]
        elif isinstance(style, (str, Path, dict)):
            # a bare specification cannot be concatenated with the kwargs
            # entry below, so normalize it to a list first
            style = [style]

        # kwargs are kept as the last entry, after the named styles
        self.style = [*style, kwargs]
        self.previous_style = {}

    def __enter__(self) -> None:
        # snapshot first so that __exit__ can undo whatever the styles change
        self.previous_style = mpl.rcParams.copy()
        plt.style.use(self.style)

    def __exit__(self, *args: object, **kwargs: object) -> None:
        mpl.rcParams.update(self.previous_style)
def set_style(
    style: list[str] | str | Path | None = None,
    **kwargs: str,
) -> None:
    """
    Apply Matplotlib styles via ``plt.style.use``.

    Unlike a context manager, this function does not record the previous
    settings, so nothing is restored afterwards.

    Parameters:
        style: style specification(s) to apply. May be a list of
            specifications or a single one (``str``, ``Path`` or ``dict``),
            which is wrapped into a one-element list. Defaults to the
            ``custom.mplstyle`` sheet located next to this module.
        **kwargs: individual rcParam overrides, passed to ``plt.style.use``
            as the final entry of the style list
    """

    if style is None:
        style = [str(Path(FILE, "custom.mplstyle"))]
    elif isinstance(style, (str, Path, dict)):
        # a bare specification cannot be concatenated with the kwargs entry
        # below, so normalize it to a list first
        style = [style]

    plt.style.use([*style, kwargs])

21
trainlib/utils/session.py Normal file
View File

@@ -0,0 +1,21 @@
import random
import numpy as np
import torch
from torch import Tensor
def seed_all_backends(seed: int | Tensor | None = None) -> None:
"""Sets all python, numpy and pytorch seeds."""
if seed is None:
seed = int(torch.randint(1000000, size=(1,)))
else:
seed = int(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

100
uv.lock generated
View File

@@ -248,9 +248,13 @@ dependencies = [
{ name = "cuda-pathfinder" }, { name = "cuda-pathfinder" },
] ]
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/05/8b/b4b2d1c7775fa403b64333e720cfcfccef8dcb9cdeb99947061ca5a77628/cuda_bindings-12.9.4-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cf8bfaedc238f3b115d957d1fd6562b7e8435ba57f6d0e2f87d0e7149ccb2da5", size = 11570071, upload-time = "2025-10-21T14:51:47.472Z" },
{ url = "https://files.pythonhosted.org/packages/63/56/e465c31dc9111be3441a9ba7df1941fe98f4aa6e71e8788a3fb4534ce24d/cuda_bindings-12.9.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:32bdc5a76906be4c61eb98f546a6786c5773a881f3b166486449b5d141e4a39f", size = 11906628, upload-time = "2025-10-21T14:51:49.905Z" }, { url = "https://files.pythonhosted.org/packages/63/56/e465c31dc9111be3441a9ba7df1941fe98f4aa6e71e8788a3fb4534ce24d/cuda_bindings-12.9.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:32bdc5a76906be4c61eb98f546a6786c5773a881f3b166486449b5d141e4a39f", size = 11906628, upload-time = "2025-10-21T14:51:49.905Z" },
{ url = "https://files.pythonhosted.org/packages/ec/07/6aff13bc1e977e35aaa6b22f52b172e2890c608c6db22438cf7ed2bf43a6/cuda_bindings-12.9.4-cp313-cp313t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3adf4958dcf68ae7801a59b73fb00a8b37f8d0595060d66ceae111b1002de38d", size = 11566797, upload-time = "2025-10-21T14:51:54.581Z" },
{ url = "https://files.pythonhosted.org/packages/a3/84/1e6be415e37478070aeeee5884c2022713c1ecc735e6d82d744de0252eee/cuda_bindings-12.9.4-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:56e0043c457a99ac473ddc926fe0dc4046694d99caef633e92601ab52cbe17eb", size = 11925991, upload-time = "2025-10-21T14:51:56.535Z" }, { url = "https://files.pythonhosted.org/packages/a3/84/1e6be415e37478070aeeee5884c2022713c1ecc735e6d82d744de0252eee/cuda_bindings-12.9.4-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:56e0043c457a99ac473ddc926fe0dc4046694d99caef633e92601ab52cbe17eb", size = 11925991, upload-time = "2025-10-21T14:51:56.535Z" },
{ url = "https://files.pythonhosted.org/packages/1e/b5/96a6696e20c4ffd2b327f54c7d0fde2259bdb998d045c25d5dedbbe30290/cuda_bindings-12.9.4-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1f53a7f453d4b2643d8663d036bafe29b5ba89eb904c133180f295df6dc151e5", size = 11624530, upload-time = "2025-10-21T14:52:01.539Z" },
{ url = "https://files.pythonhosted.org/packages/d1/af/6dfd8f2ed90b1d4719bc053ff8940e494640fe4212dc3dd72f383e4992da/cuda_bindings-12.9.4-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8b72ee72a9cc1b531db31eebaaee5c69a8ec3500e32c6933f2d3b15297b53686", size = 11922703, upload-time = "2025-10-21T14:52:03.585Z" }, { url = "https://files.pythonhosted.org/packages/d1/af/6dfd8f2ed90b1d4719bc053ff8940e494640fe4212dc3dd72f383e4992da/cuda_bindings-12.9.4-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8b72ee72a9cc1b531db31eebaaee5c69a8ec3500e32c6933f2d3b15297b53686", size = 11922703, upload-time = "2025-10-21T14:52:03.585Z" },
{ url = "https://files.pythonhosted.org/packages/39/73/d2fc40c043bac699c3880bf88d3cebe9d88410cd043795382826c93a89f0/cuda_bindings-12.9.4-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:20f2699d61d724de3eb3f3369d57e2b245f93085cab44fd37c3bea036cea1a6f", size = 11565056, upload-time = "2025-10-21T14:52:08.338Z" },
{ url = "https://files.pythonhosted.org/packages/6c/19/90ac264acc00f6df8a49378eedec9fd2db3061bf9263bf9f39fd3d8377c3/cuda_bindings-12.9.4-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d80bffc357df9988dca279734bc9674c3934a654cab10cadeed27ce17d8635ee", size = 11924658, upload-time = "2025-10-21T14:52:10.411Z" }, { url = "https://files.pythonhosted.org/packages/6c/19/90ac264acc00f6df8a49378eedec9fd2db3061bf9263bf9f39fd3d8377c3/cuda_bindings-12.9.4-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d80bffc357df9988dca279734bc9674c3934a654cab10cadeed27ce17d8635ee", size = 11924658, upload-time = "2025-10-21T14:52:10.411Z" },
] ]
@@ -861,6 +865,7 @@ name = "nvidia-cublas-cu12"
version = "12.8.4.1" version = "12.8.4.1"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/29/99/db44d685f0e257ff0e213ade1964fc459b4a690a73293220e98feb3307cf/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:b86f6dd8935884615a0683b663891d43781b819ac4f2ba2b0c9604676af346d0", size = 590537124, upload-time = "2025-03-07T01:43:53.556Z" },
{ url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921, upload-time = "2025-03-07T01:44:31.254Z" }, { url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921, upload-time = "2025-03-07T01:44:31.254Z" },
] ]
@@ -869,6 +874,7 @@ name = "nvidia-cuda-cupti-cu12"
version = "12.8.90" version = "12.8.90"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/d5/1f/b3bd73445e5cb342727fd24fe1f7b748f690b460acadc27ea22f904502c8/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:4412396548808ddfed3f17a467b104ba7751e6b58678a4b840675c56d21cf7ed", size = 9533318, upload-time = "2025-03-07T01:40:10.421Z" },
{ url = "https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621, upload-time = "2025-03-07T01:40:21.213Z" }, { url = "https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621, upload-time = "2025-03-07T01:40:21.213Z" },
] ]
@@ -878,6 +884,7 @@ version = "12.8.93"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029, upload-time = "2025-03-07T01:42:13.562Z" }, { url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029, upload-time = "2025-03-07T01:42:13.562Z" },
{ url = "https://files.pythonhosted.org/packages/eb/d1/e50d0acaab360482034b84b6e27ee83c6738f7d32182b987f9c7a4e32962/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fc1fec1e1637854b4c0a65fb9a8346b51dd9ee69e61ebaccc82058441f15bce8", size = 43106076, upload-time = "2025-03-07T01:41:59.817Z" },
] ]
[[package]] [[package]]
@@ -885,6 +892,7 @@ name = "nvidia-cuda-runtime-cu12"
version = "12.8.90" version = "12.8.90"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/7c/75/f865a3b236e4647605ea34cc450900854ba123834a5f1598e160b9530c3a/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:52bf7bbee900262ffefe5e9d5a2a69a30d97e2bc5bb6cc866688caa976966e3d", size = 965265, upload-time = "2025-03-07T01:39:43.533Z" },
{ url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765, upload-time = "2025-03-07T01:40:01.615Z" }, { url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765, upload-time = "2025-03-07T01:40:01.615Z" },
] ]
@@ -896,6 +904,7 @@ dependencies = [
{ name = "nvidia-cublas-cu12" }, { name = "nvidia-cublas-cu12" },
] ]
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/fa/41/e79269ce215c857c935fd86bcfe91a451a584dfc27f1e068f568b9ad1ab7/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8", size = 705026878, upload-time = "2025-06-06T21:52:51.348Z" },
{ url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" },
] ]
@@ -907,6 +916,7 @@ dependencies = [
{ name = "nvidia-nvjitlink-cu12" }, { name = "nvidia-nvjitlink-cu12" },
] ]
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = "2025-03-07T01:44:56.873Z" },
{ url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" },
] ]
@@ -916,6 +926,7 @@ version = "1.13.1.3"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834, upload-time = "2025-03-07T01:45:50.723Z" }, { url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834, upload-time = "2025-03-07T01:45:50.723Z" },
{ url = "https://files.pythonhosted.org/packages/1e/f5/5607710447a6fe9fd9b3283956fceeee8a06cda1d2f56ce31371f595db2a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:4beb6d4cce47c1a0f1013d72e02b0994730359e17801d395bdcbf20cfb3bb00a", size = 1120705, upload-time = "2025-03-07T01:45:41.434Z" },
] ]
[[package]] [[package]]
@@ -923,6 +934,7 @@ name = "nvidia-curand-cu12"
version = "10.3.9.90" version = "10.3.9.90"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/45/5e/92aa15eca622a388b80fbf8375d4760738df6285b1e92c43d37390a33a9a/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:dfab99248034673b779bc6decafdc3404a8a6f502462201f2f31f11354204acd", size = 63625754, upload-time = "2025-03-07T01:46:10.735Z" },
{ url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976, upload-time = "2025-03-07T01:46:23.323Z" }, { url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976, upload-time = "2025-03-07T01:46:23.323Z" },
] ]
@@ -936,6 +948,7 @@ dependencies = [
{ name = "nvidia-nvjitlink-cu12" }, { name = "nvidia-nvjitlink-cu12" },
] ]
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = "2025-03-07T01:46:54.356Z" },
{ url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" },
] ]
@@ -947,6 +960,7 @@ dependencies = [
{ name = "nvidia-nvjitlink-cu12" }, { name = "nvidia-nvjitlink-cu12" },
] ]
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129, upload-time = "2025-03-07T01:47:40.407Z" },
{ url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" },
] ]
@@ -955,6 +969,7 @@ name = "nvidia-cusparselt-cu12"
version = "0.7.1" version = "0.7.1"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/73/b9/598f6ff36faaece4b3c50d26f50e38661499ff34346f00e057760b35cc9d/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8878dce784d0fac90131b6817b607e803c36e629ba34dc5b433471382196b6a5", size = 283835557, upload-time = "2025-02-26T00:16:54.265Z" },
{ url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691, upload-time = "2025-02-26T00:15:44.104Z" }, { url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691, upload-time = "2025-02-26T00:15:44.104Z" },
] ]
@@ -963,6 +978,7 @@ name = "nvidia-nccl-cu12"
version = "2.27.5" version = "2.27.5"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/bb/1c/857979db0ef194ca5e21478a0612bcdbbe59458d7694361882279947b349/nvidia_nccl_cu12-2.27.5-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:31432ad4d1fb1004eb0c56203dc9bc2178a1ba69d1d9e02d64a6938ab5e40e7a", size = 322400625, upload-time = "2025-06-26T04:11:04.496Z" },
{ url = "https://files.pythonhosted.org/packages/6e/89/f7a07dc961b60645dbbf42e80f2bc85ade7feb9a491b11a1e973aa00071f/nvidia_nccl_cu12-2.27.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ad730cf15cb5d25fe849c6e6ca9eb5b76db16a80f13f425ac68d8e2e55624457", size = 322348229, upload-time = "2025-06-26T04:11:28.385Z" }, { url = "https://files.pythonhosted.org/packages/6e/89/f7a07dc961b60645dbbf42e80f2bc85ade7feb9a491b11a1e973aa00071f/nvidia_nccl_cu12-2.27.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ad730cf15cb5d25fe849c6e6ca9eb5b76db16a80f13f425ac68d8e2e55624457", size = 322348229, upload-time = "2025-06-26T04:11:28.385Z" },
] ]
@@ -972,6 +988,7 @@ version = "12.8.93"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836, upload-time = "2025-03-07T01:49:55.661Z" }, { url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836, upload-time = "2025-03-07T01:49:55.661Z" },
{ url = "https://files.pythonhosted.org/packages/2a/a2/8cee5da30d13430e87bf99bb33455d2724d0a4a9cb5d7926d80ccb96d008/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:adccd7161ace7261e01bb91e44e88da350895c270d23f744f0820c818b7229e7", size = 38386204, upload-time = "2025-03-07T01:49:43.612Z" },
] ]
[[package]] [[package]]
@@ -979,6 +996,7 @@ name = "nvidia-nvshmem-cu12"
version = "3.4.5" version = "3.4.5"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/1d/6a/03aa43cc9bd3ad91553a88b5f6fb25ed6a3752ae86ce2180221962bc2aa5/nvidia_nvshmem_cu12-3.4.5-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0b48363fc6964dede448029434c6abed6c5e37f823cb43c3bcde7ecfc0457e15", size = 138936938, upload-time = "2025-09-06T00:32:05.589Z" },
{ url = "https://files.pythonhosted.org/packages/b5/09/6ea3ea725f82e1e76684f0708bbedd871fc96da89945adeba65c3835a64c/nvidia_nvshmem_cu12-3.4.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:042f2500f24c021db8a06c5eec2539027d57460e1c1a762055a6554f72c369bd", size = 139103095, upload-time = "2025-09-06T00:32:31.266Z" }, { url = "https://files.pythonhosted.org/packages/b5/09/6ea3ea725f82e1e76684f0708bbedd871fc96da89945adeba65c3835a64c/nvidia_nvshmem_cu12-3.4.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:042f2500f24c021db8a06c5eec2539027d57460e1c1a762055a6554f72c369bd", size = 139103095, upload-time = "2025-09-06T00:32:31.266Z" },
] ]
@@ -987,6 +1005,7 @@ name = "nvidia-nvtx-cu12"
version = "12.8.90" version = "12.8.90"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/10/c0/1b303feea90d296f6176f32a2a70b5ef230f9bdeb3a72bddb0dc922dc137/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d7ad891da111ebafbf7e015d34879f7112832fc239ff0d7d776b6cb685274615", size = 91161, upload-time = "2025-03-07T01:42:23.922Z" },
{ url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" }, { url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" },
] ]
@@ -1393,14 +1412,14 @@ wheels = [
[[package]] [[package]]
name = "sphinx-autodoc-typehints" name = "sphinx-autodoc-typehints"
version = "3.9.5" version = "3.9.7"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "sphinx" }, { name = "sphinx" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/58/ec/21bd9babcfeb9930a73011257002d5cfa5fd30667b8de6d76dbaf8275dfb/sphinx_autodoc_typehints-3.9.5.tar.gz", hash = "sha256:60e646efb7c352a0e98f34dd7fdcde4527fbbdbdf30371ff8321b6b3eb1fd37d", size = 63249, upload-time = "2026-03-02T19:58:07.974Z" } sdist = { url = "https://files.pythonhosted.org/packages/f4/06/da2d9e98b3f7f0df144496e62f453e0025f129bccc7a6076b8ceae6047b1/sphinx_autodoc_typehints-3.9.7.tar.gz", hash = "sha256:70f3dd4e4dd815ae30e5d3848a26dca71fb5e7fcf8f37cf8b840dc8afdf07e82", size = 68689, upload-time = "2026-03-05T18:33:40.829Z" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/7f/cb/80c250f47a0ca5ac67d82f14811b4068a551a12b4790b085ffdb900de427/sphinx_autodoc_typehints-3.9.5-py3-none-any.whl", hash = "sha256:c94f88a90b6c61a7a6686cb77b410e46e077712838387e6cf22d69e85cfd06a5", size = 34763, upload-time = "2026-03-02T19:58:06.028Z" }, { url = "https://files.pythonhosted.org/packages/a4/a0/e7d3365dabfa79a1b2ac7d3122b5b22b401a9c4d5e4eadc5e13b88c63a2c/sphinx_autodoc_typehints-3.9.7-py3-none-any.whl", hash = "sha256:dd73f6a32adef0d8208f6f7d99254e1880259c77db7b4a91648345d45202d48e", size = 36691, upload-time = "2026-03-05T18:33:38.983Z" },
] ]
[[package]] [[package]]
@@ -1542,52 +1561,47 @@ wheels = [
[[package]] [[package]]
name = "torch" name = "torch"
version = "2.10.0" version = "2.10.0+cu128"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://download.pytorch.org/whl/cu128" }
dependencies = [ dependencies = [
{ name = "cuda-bindings", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "cuda-bindings", marker = "sys_platform == 'linux'" },
{ name = "filelock" }, { name = "filelock" },
{ name = "fsspec" }, { name = "fsspec" },
{ name = "jinja2" }, { name = "jinja2" },
{ name = "networkx" }, { name = "networkx" },
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cuda-cupti-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cuda-nvrtc-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cuda-runtime-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cudnn-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cufft-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cufile-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-curand-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cusolver-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cusparselt-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nccl-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-nvshmem-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nvshmem-cu12", marker = "sys_platform == 'linux'" },
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nvtx-cu12", marker = "sys_platform == 'linux'" },
{ name = "setuptools" }, { name = "setuptools" },
{ name = "sympy" }, { name = "sympy" },
{ name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "triton", marker = "sys_platform == 'linux'" },
{ name = "typing-extensions" }, { name = "typing-extensions" },
] ]
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/ec/23/2c9fe0c9c27f7f6cb865abcea8a4568f29f00acaeadfc6a37f6801f84cb4/torch-2.10.0-2-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:e521c9f030a3774ed770a9c011751fb47c4d12029a3d6522116e48431f2ff89e", size = 79498254, upload-time = "2026-02-10T21:44:44.095Z" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:bdbcc703382f948e951c063448c9406bf38ce66c41dd698d9e2733fcf96c037a" },
{ url = "https://files.pythonhosted.org/packages/c9/6f/f2e91e34e3fcba2e3fc8d8f74e7d6c22e74e480bbd1db7bc8900fdf3e95c/torch-2.10.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:5c4d217b14741e40776dd7074d9006fd28b8a97ef5654db959d8635b2fe5f29b", size = 146004247, upload-time = "2026-01-21T16:24:29.335Z" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:7b4bd23ed63de97456fcc81c26fea9f02ee02ce1112111c4dac0d8cfe574b23e" },
{ url = "https://files.pythonhosted.org/packages/98/fb/5160261aeb5e1ee12ee95fe599d0541f7c976c3701d607d8fc29e623229f/torch-2.10.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:6b71486353fce0f9714ca0c9ef1c850a2ae766b409808acd58e9678a3edb7738", size = 915716445, upload-time = "2026-01-21T16:22:45.353Z" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313-win_amd64.whl", hash = "sha256:4d1b0b49c54223c7c04050b49eac141d77b6edbc34aea1dfc74a6fdb661baa8c" },
{ url = "https://files.pythonhosted.org/packages/6a/16/502fb1b41e6d868e8deb5b0e3ae926bbb36dab8ceb0d1b769b266ad7b0c3/torch-2.10.0-cp313-cp313-win_amd64.whl", hash = "sha256:c2ee399c644dc92ef7bc0d4f7e74b5360c37cdbe7c5ba11318dda49ffac2bc57", size = 113757050, upload-time = "2026-01-21T16:24:19.204Z" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:f1f8b840c64b645a4bc61a393db48effb9c92b2dc26c8373873911f0750d1ea7" },
{ url = "https://files.pythonhosted.org/packages/1a/0b/39929b148f4824bc3ad6f9f72a29d4ad865bcf7ebfc2fa67584773e083d2/torch-2.10.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:3202429f58309b9fa96a614885eace4b7995729f44beb54d3e4a47773649d382", size = 79851305, upload-time = "2026-01-21T16:24:09.209Z" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:23f58258012bcf1c349cb22af387e33aadca7f83ea617b080e774eb41e4fe8ff" },
{ url = "https://files.pythonhosted.org/packages/d8/14/21fbce63bc452381ba5f74a2c0a959fdf5ad5803ccc0c654e752e0dbe91a/torch-2.10.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:aae1b29cd68e50a9397f5ee897b9c24742e9e306f88a807a27d617f07adb3bd8", size = 146005472, upload-time = "2026-01-21T16:22:29.022Z" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313t-win_amd64.whl", hash = "sha256:01b216e097b17a5277cfb47c383cdcacf06abeadcb0daca0c76b59e72854c3b6" },
{ url = "https://files.pythonhosted.org/packages/54/fd/b207d1c525cb570ef47f3e9f836b154685011fce11a2f444ba8a4084d042/torch-2.10.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:6021db85958db2f07ec94e1bc77212721ba4920c12a18dc552d2ae36a3eb163f", size = 915612644, upload-time = "2026-01-21T16:21:47.019Z" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:c42377bc2607e3e1c60da71b792fb507c3938c87fd6edab8b21c59c91473c36d" },
{ url = "https://files.pythonhosted.org/packages/36/53/0197f868c75f1050b199fe58f9bf3bf3aecac9b4e85cc9c964383d745403/torch-2.10.0-cp313-cp313t-win_amd64.whl", hash = "sha256:ff43db38af76fda183156153983c9a096fc4c78d0cd1e07b14a2314c7f01c2c8", size = 113997015, upload-time = "2026-01-21T16:23:00.767Z" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:37d71feea068776855686a1512058df3f19f6f040a151f055aa746601678744f" },
{ url = "https://files.pythonhosted.org/packages/0e/13/e76b4d9c160e89fff48bf16b449ea324bda84745d2ab30294c37c2434c0d/torch-2.10.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:cdf2a523d699b70d613243211ecaac14fe9c5df8a0b0a9c02add60fb2a413e0f", size = 79498248, upload-time = "2026-01-21T16:23:09.315Z" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314-win_amd64.whl", hash = "sha256:c57017ca29e62271e362fdeee7d20070e254755a5148b30b553d8a10fc83c7ef" },
{ url = "https://files.pythonhosted.org/packages/4f/93/716b5ac0155f1be70ed81bacc21269c3ece8dba0c249b9994094110bfc51/torch-2.10.0-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:bf0d9ff448b0218e0433aeb198805192346c4fd659c852370d5cc245f602a06a", size = 79464992, upload-time = "2026-01-21T16:23:05.162Z" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:777461f50b2daf77e4bdd8e2ad34bdfc5a993bf1bdf2ab9ef39f5edfe4e9c12b" },
{ url = "https://files.pythonhosted.org/packages/69/2b/51e663ff190c9d16d4a8271203b71bc73a16aa7619b9f271a69b9d4a936b/torch-2.10.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:233aed0659a2503b831d8a67e9da66a62c996204c0bba4f4c442ccc0c68a3f60", size = 146018567, upload-time = "2026-01-21T16:22:23.393Z" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:7bcba6a7c5f0987a13298b1ca843155dcceceac758fa3c7ccd5c7af4059a1080" },
{ url = "https://files.pythonhosted.org/packages/5e/cd/4b95ef7f293b927c283db0b136c42be91c8ec6845c44de0238c8c23bdc80/torch-2.10.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:682497e16bdfa6efeec8cde66531bc8d1fbbbb4d8788ec6173c089ed3cc2bfe5", size = 915721646, upload-time = "2026-01-21T16:21:16.983Z" }, { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314t-win_amd64.whl", hash = "sha256:70d89143c956389d4806cb4e5fe0b1129fe0db280e1073288d17fa76c101cba4" },
{ url = "https://files.pythonhosted.org/packages/56/97/078a007208f8056d88ae43198833469e61a0a355abc0b070edd2c085eb9a/torch-2.10.0-cp314-cp314-win_amd64.whl", hash = "sha256:6528f13d2a8593a1a412ea07a99812495bec07e9224c28b2a25c0a30c7da025c", size = 113752373, upload-time = "2026-01-21T16:22:13.471Z" },
{ url = "https://files.pythonhosted.org/packages/d8/94/71994e7d0d5238393df9732fdab607e37e2b56d26a746cb59fdb415f8966/torch-2.10.0-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:f5ab4ba32383061be0fb74bda772d470140a12c1c3b58a0cfbf3dae94d164c28", size = 79850324, upload-time = "2026-01-21T16:22:09.494Z" },
{ url = "https://files.pythonhosted.org/packages/e2/65/1a05346b418ea8ccd10360eef4b3e0ce688fba544e76edec26913a8d0ee0/torch-2.10.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:716b01a176c2a5659c98f6b01bf868244abdd896526f1c692712ab36dbaf9b63", size = 146006482, upload-time = "2026-01-21T16:22:18.42Z" },
{ url = "https://files.pythonhosted.org/packages/1d/b9/5f6f9d9e859fc3235f60578fa64f52c9c6e9b4327f0fe0defb6de5c0de31/torch-2.10.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:d8f5912ba938233f86361e891789595ff35ca4b4e2ac8fe3670895e5976731d6", size = 915613050, upload-time = "2026-01-21T16:20:49.035Z" },
{ url = "https://files.pythonhosted.org/packages/66/4d/35352043ee0eaffdeff154fad67cd4a31dbed7ff8e3be1cc4549717d6d51/torch-2.10.0-cp314-cp314t-win_amd64.whl", hash = "sha256:71283a373f0ee2c89e0f0d5f446039bdabe8dbc3c9ccf35f0f784908b0acd185", size = 113995816, upload-time = "2026-01-21T16:22:05.312Z" },
] ]
[[package]] [[package]]
@@ -1623,7 +1637,7 @@ wheels = [
[[package]] [[package]]
name = "trainlib" name = "trainlib"
version = "0.1.0" version = "0.1.1"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "colorama" }, { name = "colorama" },
@@ -1662,7 +1676,7 @@ requires-dist = [
{ name = "sphinx-autodoc-typehints", marker = "extra == 'doc'" }, { name = "sphinx-autodoc-typehints", marker = "extra == 'doc'" },
{ name = "sphinx-togglebutton", marker = "extra == 'doc'" }, { name = "sphinx-togglebutton", marker = "extra == 'doc'" },
{ name = "tensorboard", specifier = ">=2.20.0" }, { name = "tensorboard", specifier = ">=2.20.0" },
{ name = "torch", specifier = ">=2.5.1" }, { name = "torch", index = "https://download.pytorch.org/whl/cu128" },
{ name = "tqdm", specifier = ">=4.67.1" }, { name = "tqdm", specifier = ">=4.67.1" },
] ]
provides-extras = ["dev", "doc", "test"] provides-extras = ["dev", "doc", "test"]
@@ -1681,9 +1695,13 @@ name = "triton"
version = "3.6.0" version = "3.6.0"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/3c/12/34d71b350e89a204c2c7777a9bba0dcf2f19a5bfdd70b57c4dbc5ffd7154/triton-3.6.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:448e02fe6dc898e9e5aa89cf0ee5c371e99df5aa5e8ad976a80b93334f3494fd", size = 176133521, upload-time = "2026-01-20T16:16:13.321Z" },
{ url = "https://files.pythonhosted.org/packages/f9/0b/37d991d8c130ce81a8728ae3c25b6e60935838e9be1b58791f5997b24a54/triton-3.6.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10c7f76c6e72d2ef08df639e3d0d30729112f47a56b0c81672edc05ee5116ac9", size = 188289450, upload-time = "2026-01-20T16:00:49.136Z" }, { url = "https://files.pythonhosted.org/packages/f9/0b/37d991d8c130ce81a8728ae3c25b6e60935838e9be1b58791f5997b24a54/triton-3.6.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10c7f76c6e72d2ef08df639e3d0d30729112f47a56b0c81672edc05ee5116ac9", size = 188289450, upload-time = "2026-01-20T16:00:49.136Z" },
{ url = "https://files.pythonhosted.org/packages/ce/4e/41b0c8033b503fd3cfcd12392cdd256945026a91ff02452bef40ec34bee7/triton-3.6.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1722e172d34e32abc3eb7711d0025bb69d7959ebea84e3b7f7a341cd7ed694d6", size = 176276087, upload-time = "2026-01-20T16:16:18.989Z" },
{ url = "https://files.pythonhosted.org/packages/35/f8/9c66bfc55361ec6d0e4040a0337fb5924ceb23de4648b8a81ae9d33b2b38/triton-3.6.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d002e07d7180fd65e622134fbd980c9a3d4211fb85224b56a0a0efbd422ab72f", size = 188400296, upload-time = "2026-01-20T16:00:56.042Z" }, { url = "https://files.pythonhosted.org/packages/35/f8/9c66bfc55361ec6d0e4040a0337fb5924ceb23de4648b8a81ae9d33b2b38/triton-3.6.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d002e07d7180fd65e622134fbd980c9a3d4211fb85224b56a0a0efbd422ab72f", size = 188400296, upload-time = "2026-01-20T16:00:56.042Z" },
{ url = "https://files.pythonhosted.org/packages/49/55/5ecf0dcaa0f2fbbd4420f7ef227ee3cb172e91e5fede9d0ecaddc43363b4/triton-3.6.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ef5523241e7d1abca00f1d240949eebdd7c673b005edbbce0aca95b8191f1d43", size = 176138577, upload-time = "2026-01-20T16:16:25.426Z" },
{ url = "https://files.pythonhosted.org/packages/df/3d/9e7eee57b37c80cec63322c0231bb6da3cfe535a91d7a4d64896fcb89357/triton-3.6.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a17a5d5985f0ac494ed8a8e54568f092f7057ef60e1b0fa09d3fd1512064e803", size = 188273063, upload-time = "2026-01-20T16:01:07.278Z" }, { url = "https://files.pythonhosted.org/packages/df/3d/9e7eee57b37c80cec63322c0231bb6da3cfe535a91d7a4d64896fcb89357/triton-3.6.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a17a5d5985f0ac494ed8a8e54568f092f7057ef60e1b0fa09d3fd1512064e803", size = 188273063, upload-time = "2026-01-20T16:01:07.278Z" },
{ url = "https://files.pythonhosted.org/packages/48/db/56ee649cab5eaff4757541325aca81f52d02d4a7cd3506776cad2451e060/triton-3.6.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0b3a97e8ed304dfa9bd23bb41ca04cdf6b2e617d5e782a8653d616037a5d537d", size = 176274804, upload-time = "2026-01-20T16:16:31.528Z" },
{ url = "https://files.pythonhosted.org/packages/f6/56/6113c23ff46c00aae423333eb58b3e60bdfe9179d542781955a5e1514cb3/triton-3.6.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:46bd1c1af4b6704e554cad2eeb3b0a6513a980d470ccfa63189737340c7746a7", size = 188397994, upload-time = "2026-01-20T16:01:14.236Z" }, { url = "https://files.pythonhosted.org/packages/f6/56/6113c23ff46c00aae423333eb58b3e60bdfe9179d542781955a5e1514cb3/triton-3.6.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:46bd1c1af4b6704e554cad2eeb3b0a6513a980d470ccfa63189737340c7746a7", size = 188397994, upload-time = "2026-01-20T16:01:14.236Z" },
] ]