From e867bc0e7f68ab3a4e56806c63581f2c92935f65 Mon Sep 17 00:00:00 2001 From: smgr Date: Sat, 7 Mar 2026 03:10:13 -0800 Subject: [PATCH] add plot styles, clean up package-wide docstrings --- doc/conf.py | 5 +- doc/index.md | 4 +- pyproject.toml | 10 +- trainlib/dataset.py | 216 +++++++++++++++++++++++++-------- trainlib/estimators/rnn.py | 15 ++- trainlib/trainer.py | 52 +++++--- trainlib/transform.py | 10 ++ trainlib/utils/custom.mplstyle | 46 +++++++ trainlib/utils/plot.py | 38 ++++++ trainlib/utils/session.py | 21 ++++ uv.lock | 100 ++++++++------- 11 files changed, 398 insertions(+), 119 deletions(-) create mode 100644 trainlib/utils/custom.mplstyle create mode 100644 trainlib/utils/plot.py create mode 100644 trainlib/utils/session.py diff --git a/doc/conf.py b/doc/conf.py index 9753af1..e287218 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -6,8 +6,8 @@ # -- Project information ------------------------------------------------------ # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information -project = "" -copyright = "2025, Sam Griesemer" +project = "trainlib" +copyright = "2026, Sam Griesemer" author = "Sam Griesemer" # -- General configuration ---------------------------------------------------- @@ -44,6 +44,7 @@ exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] html_theme = "furo" html_static_path = ["_static"] + # html_sidebars = { # '**': ['/modules.html'], # } diff --git a/doc/index.md b/doc/index.md index f8fc36c..1b73cee 100644 --- a/doc/index.md +++ b/doc/index.md @@ -1,4 +1,4 @@ -# `` package docs +# `trainlib` package docs {ref}`genindex` {ref}`modindex` {ref}`search` @@ -14,7 +14,7 @@ :maxdepth: 3 :caption: Autoref -_autoref/.rst +_autoref/index.rst ``` ```{toctree} diff --git a/pyproject.toml b/pyproject.toml index b491b97..3dda38c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,11 +24,11 @@ classifiers = [ "Intended Audience :: End Users/Desktop", ] dependencies = [ + "torch", "colorama>=0.4.6", 
"matplotlib>=3.10.8", "numpy>=2.4.1", "tensorboard>=2.20.0", - "torch>=2.5.1", "tqdm>=4.67.1", ] @@ -82,3 +82,11 @@ force-sort-within-sections = false quote-style = "double" indent-style = "space" docstring-code-format = true + +[tool.uv.sources] +torch = { index = "pytorch" } + +[[tool.uv.index]] +name = "pytorch" +url = "https://download.pytorch.org/whl/cu128" +explicit = true diff --git a/trainlib/dataset.py b/trainlib/dataset.py index 8429bbc..fdc4341 100644 --- a/trainlib/dataset.py +++ b/trainlib/dataset.py @@ -7,12 +7,12 @@ could have a single dataset definition for a particular concrete dataset, and so long as we're talking about the same items, it can be instantiated using *any domain*. You wouldn't need specific subclasses for disk or - network or in-memory; you can tell it directly at runtime. + network or in-memory structures; you can tell it directly at runtime. That's an eventually possibility, anyway. As it stands, however, this is effectively impossible: - You can't easily abstract the batch -> item splitting process, i.e., + You can't easily abstract the batch-to-item splitting process, i.e., ``_process_batch_data()``. A list-based version of the dataset you're trying to define might have an individual item tuple at every index, whereas a disk-based version might have tuples batched across a few files. @@ -41,22 +41,22 @@ accepting a batch processing method upon instantiation rather than structurally bolting it into the ``Dataset`` definition. This would require knowledge of the item structure ``I`` as well as the ``Domain[U, R]``, so - such a function will always have to be (I, U, R)-dependent. It nevertheless - would take out some of the pain of having to define new dataset classes; - instead, you'd just need to define the batch processing method. 
I see this - as a worse alternative to just defining *inside* a safe context like a new - dataset class: you know the types you have to respect, and you stick that - method exactly in a context where it's understood. Freeing this up doesn't - lighten the burden of processing logic, it just changes *when* it has to be - provided, and that's not worth much (to me) in this case given the bump in - complexity. (Taking this to the extreme: you could supply *all* of an - object's methods "dynamically" and glue them together at runtime so long as - they all played nice. But wherever you were "laying them out" beforehand is - exactly the job of a class to begin with, so you don't end up with anything - more dynamic. All we're really discussing here is pushing around - unavoidable complexity inside and outside of the "class walls," and in the - particular case of ``_process_batch_data()``, it feels much better when - it's on the inside.) + such a function will always have to be ``(I, U, R)``-dependent. It + nevertheless would take out some of the pain of having to define new + dataset classes; instead, you'd just need to define the batch processing + method. I see this as a worse alternative to just defining *inside* a safe + context like a new dataset class: you know the types you have to respect, + and you stick that method exactly in a context where it's understood. + Freeing this up doesn't lighten the burden of processing logic, it just + changes *when* it has to be provided, and that's not worth much (to me) in + this case given the bump in complexity. (Taking this to the extreme: you + could supply *all* of an object's methods "dynamically" and glue them + together at runtime so long as they all played nice. But wherever you were + "laying them out" beforehand is exactly the job of a class to begin with, + so you don't end up with anything more dynamic. 
All we're really discussing + here is pushing around unavoidable complexity inside and outside of the + "class walls," and in the particular case of ``_process_batch_data()``, it + feels much better when it's on the inside.) .. admonition:: Holding area @@ -114,6 +114,7 @@ T = TypeVar("T", bound=NamedTuple) + class NamedTupleDataset[I](Dataset): def __init__(self, data_list: list[I]) -> None: self.data_list = data_list @@ -161,43 +162,76 @@ class BatchedDataset[U, R, I](Dataset): The class is generic over a URI type ``U``, a resource type ``R`` (both of which are used to concretize a domain ``Domain[U, R]``), and an item type - ``T`` (which has a ``tuple`` upper bound). + ``I``. - - .. admonition:: Pipeline overview + .. admonition:: Batch and item processing flow - .. code-block:: python - - Domain -> [U] (get _batch_uris) - U -> R (domain access ; Rs provide batches) - R -> [I] (cache here ; _process_batch_data to use load_transform) - [I] -> I (human item obj ; _get_item) - I -> **P (final packed item ; __getitem__ to use transform) + .. code-block:: text + Domain -> [U] :: self._batch_uris = list(domain) + + Grab all URIs from Domain iterators. This is made concrete + early to allow for Dataset sizing, and we need a Sequence + representation to map integer batch indices into Domains, i.e., + when getting the corresponding URI: + + batch_uri = self._batch_uris[batch_index] + + We let Domains implement iterators over their URIs, but + explicitly exhaust when initializing Datasets. + + U -> R :: batch_data = self.domain[batch_uri] + + Retrieve resource from domain. Resources are viewed as batched + data, even if only wrapping single items (happens in trivial + settings). + + R -> [I] :: self._process_batch_data(batch_data, batch_index) + + Possibly domain-specific batch processing of resource data into + explicit Sequence-like structures of items, each of which is + subject to the provided pre_transform. Processed batches at + this stage are cached (if enabled). 
+ + [I] -> I :: self.get_batch(batch_index)[index_in_batch] + + Select individual items from batches in _get_item. At this + stage, items are in intermediate states and pulled from the + cached batches. + + I -> I :: self._process_item_data(item_data, index) + + Produce final items with __getitem__, getting intermediate + items via _get_item and applying the provided post_transform. + + .. note:: Note^1: as far as positioning, this class is meant to play nice with PyTorch DataLoaders, hence the inheritance from ``torch.Dataset``. The value add for this over the ``torch.Dataset`` base is almost entirely in the logic it implements to map out of *batched resources* that are holding data, and flattening it out into typical dataset items. There - are also some QoL items when it comes to splitting and balancing - samples. + are also some QoL features when it comes to splitting and balancing + samples. - Note^2: even though ``Domains`` implement iterators over their URIs, - this doesn't imply a ``BatchedDataset`` is iterable. This just means we - can walk over the resources that provide data, but we don't necessarily - presuppose an ordered walk over samples within batches. Point being: - ``torch.Dataset``, not ``torch.IterableDataset``, is the appropriate - superclass, even when we're working around iterable ``Domains``. + .. note:: + Even though ``Domains`` implement iterators over their URIs, this + doesn't imply a ``BatchedDataset`` is iterable. This just means we + can walk over the resources that provide data, but we don't + necessarily presuppose an ordered walk over samples within batches. + Point being: ``torch.Dataset``, not ``torch.IterableDataset``, is + the appropriate superclass, even when we're working around iterable + ``Domains``. - Note^3: transforms are expected to operate on ``I``-items and produce - ``I``-items. 
They shouldn't be the "introducers" of ``I`` types from - some other intermediate representation, nor should they map from ``I`` - to something else. Point being: the dataset definition should be able - to map resources ``R`` to ``I`` without a transform: that much should - be baked into the class definition. If you find you're expecting the - transform to do that for you, you should consider pulling in some - common structure across the allowed transforms and make it a fixed part - of the class. + .. note:: + Transforms are expected to operate on ``I``-items and produce + ``I``-items. They shouldn't be the "introducers" of ``I`` types + from some other intermediate representation, nor should they map + from ``I`` to something else. Point being: the dataset definition + should be able to map resources ``R`` to ``I`` without a transform: + that much should be baked into the class definition. If you find + you're expecting the transform to do that for you, you should + consider pulling in some common structure across the allowed + transforms and make it a fixed part of the class. """ def __init__( @@ -211,6 +245,7 @@ class BatchedDataset[U, R, I](Dataset): ) -> None: """ Parameters: + domain: ``Domain`` object providing access to batched data pre_transform: transform to apply over items during loading (in ``_process_batch_data()``), i.e., *before* going into persistent storage @@ -220,6 +255,7 @@ class BatchedDataset[U, R, I](Dataset): batch_cache_limit: the max number of max batches to cache at any one time preload: whether to load all data into memory during instantiation + num_workers: number of workers to use when preloading data """ self.domain = domain @@ -259,6 +295,9 @@ class BatchedDataset[U, R, I](Dataset): The behavior of this method can vary depending on what we know about batch sizes, and should therefore be implemented by inheriting classes. 
+ Parameters: + item_index: index of item + Returns: batch_index: int index_in_batch: int @@ -302,6 +341,10 @@ class BatchedDataset[U, R, I](Dataset): place to use a provided ``post_transform``; items are pulled from the cache (if enabled) and processed before being returned as the final tuple outputs (so this processing is not persistent). + + Parameters: + item_data: item data + item_index: index of item """ raise NotImplementedError @@ -317,6 +360,9 @@ class BatchedDataset[U, R, I](Dataset): Note that return values from `__getitem__()` are "cleaned up" versions of this representation, with minimal info needed for training. + + Parameters: + item_index: index of item """ if item_index >= len(self): @@ -355,6 +401,9 @@ class BatchedDataset[U, R, I](Dataset): any delayed reads here. There's no way around needing to see all batch data at once here, and we don't want to make that ambiguous: ``list`` output type it is. + + Parameters: + batch_index: index of batch """ logger.debug("Batch cache miss, reading from root...") @@ -374,6 +423,9 @@ class BatchedDataset[U, R, I](Dataset): Can be useful when dynamically pulling data (as it's requested) isn't desired. Requires that `cache_sample_limit=None`, i.e., the cache won't continually remove previous batches as they're loaded. + + Parameters: + num_workers: number of parallel workers to use for data loading """ assert self.batch_cache_limit is None, "Preloading under cache limit" @@ -431,11 +483,17 @@ class BatchedDataset[U, R, I](Dataset): memory, but should otherwise leave most interactions unchanged. Parameters: + frac: split fractions for datasets + dataset: dataset to split, defaults to ``self``. Facilitates + recursive splitting when multi-attribute splits are needed. + by_attr: attribute or attributes to use when grouping strata for + dataset splits. Defaults to ``None``, which will use item + indices. shuffle_strata: shuffle the strata order before split is drawn. 
We - parameterize this because a dataloader-level shuffle operation + parameterize this because a Dataloader-level shuffle operation will only change the order of the indices in the resulting - splits; only a shuffle of items inside the strata can change - the actual content of the splits themselves. + splits; only a shuffle of the strata order can change the + actual content of the splits themselves. """ if by_attr == []: @@ -544,6 +602,32 @@ class BatchedDataset[U, R, I](Dataset): split_max_sizes: list[int] | None = None, shuffle_strata: bool = True, ) -> None: + """ + Balance the distribution of provided attributes over dataset items. + + This method sets the indices over the dataset according to the result + of the rebalancing. The indices are produced by the recursive + ``_balance()`` method, which is necessarily separate due to the need + for a contained recursive approach that doesn't change the underlying + dataset during execution. + + Parameters: + dataset: dataset to split, defaults to ``self``. Facilitates + recursive splitting when multi-attribute splits are needed. + by_attr: attribute or attributes to use when grouping strata for + dataset splits. Defaults to ``None``, which will use item + indices. + split_min_sizes: minimum allowed sizes of splits. Must have the + same length as ``by_attr``. + split_max_sizes: maximum allowed sizes of splits. Must have the + same length as ``by_attr``. + shuffle_strata: shuffle the strata order before split is drawn. We + parameterize this because a Dataloader-level shuffle operation + will only change the order of the indices in the resulting + splits; only a shuffle of the strata order can change the + actual content of the splits themselves. + """ + self.indices = self._balance( dataset, by_attr, @@ -561,9 +645,27 @@ class BatchedDataset[U, R, I](Dataset): shuffle_strata: bool = True, ) -> list[int]: """ - Note: behavior is a little odd for nested behavior; not exactly - perfectly uniform throughout. 
This is a little difficult: you can't - exactly know ahead of time the size of the subgroups across splits + .. note:: + + Behavior is a little odd for nested behavior; not exactly perfectly + uniform throughout. This is a little difficult: you can't exactly + know ahead of time the size of the subgroups across splits + + Parameters: + dataset: dataset to split, defaults to ``self``. Facilitates + recursive splitting when multi-attribute splits are needed. + by_attr: attribute or attributes to use when grouping strata for + dataset splits. Defaults to ``None``, which will use item + indices. + split_min_sizes: minimum allowed sizes of splits. Must have the + same length as ``by_attr``. + split_max_sizes: maximum allowed sizes of splits. Must have the + same length as ``by_attr``. + shuffle_strata: shuffle the strata order before split is drawn. We + parameterize this because a Dataloader-level shuffle operation + will only change the order of the indices in the resulting + splits; only a shuffle of the strata order can change the + actual content of the splits themselves. """ if by_attr == []: @@ -653,6 +755,9 @@ class BatchedDataset[U, R, I](Dataset): dataset. The underlying data remain the same, but when indices get set, you're effectively applying a mask over any existing indices, always operating *relative* to the existing mask. + + Parameters: + indices: list of indices to set """ # manually set new size @@ -680,6 +785,13 @@ class BatchedDataset[U, R, I](Dataset): return self._dataset_len def __getitem__(self, index: int) -> I: + """ + Get the dataset item at the specified index. + + Parameters: + index: index of item to retrieve + """ + item_data = self._get_item(index) index = self.indices[index] @@ -888,7 +1000,7 @@ class HomogenousDataset[U, R, I](BatchedDataset[U, R, I]): class HeterogenousDataset[U, R, I](BatchedDataset[U, R, I]): """ - Batched dataset where batches have arbitrary size. + Batched dataset where batches may have arbitrary size. 
Methods left for inheriting classes: diff --git a/trainlib/estimators/rnn.py b/trainlib/estimators/rnn.py index 0ca7ce6..8ff620b 100644 --- a/trainlib/estimators/rnn.py +++ b/trainlib/estimators/rnn.py @@ -107,8 +107,14 @@ class LSTM[K: RNNKwargs](Estimator[K]): with torch.no_grad(): loss = next(self.loss(**kwargs)).item() + predictions = self(**kwargs)[0] + labels = kwargs["labels"] + mae = F.l1_loss(predictions, labels).item() + return { "loss": loss, + "mse": loss, + "mae": mae, "grad_norm": get_grad_norm(self) } @@ -291,7 +297,7 @@ class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]): logger.info(f"| > {self.output_dim=}") -class ConvRNN[K: RNNKwargs](Estimator[K]): +class ConvGRU[K: RNNKwargs](Estimator[K]): """ Base recurrent convolutional architecture. @@ -441,11 +447,18 @@ class ConvRNN[K: RNNKwargs](Estimator[K]): with torch.no_grad(): loss = next(self.loss(**kwargs)).item() + predictions = self(**kwargs)[0].squeeze(-1) + labels = kwargs["labels"] + mae = F.l1_loss(predictions, labels).item() + return { "loss": loss, + "mse": loss, + "mae": mae, "grad_norm": get_grad_norm(self) } + def optimizers( self, **kwargs: Unpack[OptimizerKwargs], diff --git a/trainlib/trainer.py b/trainlib/trainer.py index 0d6882c..7055dcd 100644 --- a/trainlib/trainer.py +++ b/trainlib/trainer.py @@ -31,7 +31,15 @@ logger: logging.Logger = logging.getLogger(__name__) class Trainer[I, K: EstimatorKwargs]: """ - Training interface for updating ``Estimators`` with ``Datasets``. + Training interface for optimizing parameters of ``Estimators`` with + ``Datasets``. + + This class is generic to a dataset item type ``I`` and an estimator kwarg + type ``K``. These are the two primary components ``Trainer`` objects need + to coordinate: they ultimately rely on a provided map to ensure data items + (type ``I``) from a dataset are appropriately routed as inputs to key + estimator methods (like ``forward()`` and ``loss()``), which accept inputs + of type ``K``. 
""" def __init__( @@ -43,8 +51,10 @@ class Trainer[I, K: EstimatorKwargs]: ) -> None: """ Parameters: - estimator: `Estimator` model object + estimator: ``Estimator`` model object device: device on which to carry out training + chkpt_dir: directory to write model checkpoints + tblog_dir: directory to write TensorBoard logs """ self.device: str @@ -87,7 +97,7 @@ class Trainer[I, K: EstimatorKwargs]: def reset(self) -> None: """ - Set base tracking parameters. + Set initial tracking parameters for the primary training loop. """ self._step: int = 0 @@ -276,13 +286,13 @@ class Trainer[I, K: EstimatorKwargs]: one should take care to synchronize the sample structure with `dataset` to match that expected by ``self.estimator.loss(...)``. - .. admonition:: On batch_estimator_map + .. admonition:: On ``batch_estimator_map`` Dataloader collate functions are responsible for mapping a collection of items into an item of collections, roughly speaking. If items are tuples of tensors, - .. code-block:: + .. code-block:: text [ ( [1, 1], [1, 1] ), @@ -293,7 +303,7 @@ class Trainer[I, K: EstimatorKwargs]: the collate function maps back into the item skeleton, producing a single tuple of (stacked) tensors - .. code-block:: + .. code-block:: text ( [[1, 1], [2, 2], @@ -309,8 +319,13 @@ class Trainer[I, K: EstimatorKwargs]: ``K``). Parameters: + dataset: dataset to train the estimator + batch_estimator_map: function mapping from batch data to expected + estimator kwargs lr: learning rate (default: 1e-3) eps: adam EPS (default: 1e-8) + max_grad_norm: upper bound to use when clipping gradients. If left + as ``None``, no gradient clipping is performed. max_epochs: maximum number of training epochs stop_after_epochs: number of epochs with stagnant validation losses to allow before early stopping. 
If training stops earlier, the @@ -395,11 +410,7 @@ class Trainer[I, K: EstimatorKwargs]: # save checkpoint if self._epoch % chkpt_every == 0: - self.save_model( - self._epoch, - self.chkpt_dir, - dir_prefix - ) + self.save_model(self._epoch, self.chkpt_dir, dir_prefix) self._epoch += 1 @@ -493,7 +504,7 @@ class Trainer[I, K: EstimatorKwargs]: def _summarize(self, writer: SummaryWriter, epoch: int) -> None: """ - Flush the training summary to the TB summary writer. + Flush the training summary to the TensorBoard summary writer. """ summary_values = defaultdict(list) @@ -547,17 +558,18 @@ class Trainer[I, K: EstimatorKwargs]: chkpt_dir.mkdir(parents=True, exist_ok=True) chkpt_path.write_bytes(model_buff.getvalue()) - def load_model( - self, - epoch: int, - chkpt_dir: str, - ) -> None: + def load_model(self, epoch: int, chkpt_dir: str) -> None: """ Load a model checkpoint from a given epoch. - Note that this assumes the model was saved via `Trainer.save_model()`, - and the estimator provided to this `Trainer` instance matches the - architecture of the checkpoint model being loaded. + Note that this assumes the model was saved via + ``Trainer.save_model()``, and the estimator provided to this + ``Trainer`` instance matches the architecture of the checkpoint model + being loaded. + + Parameters: + epoch: epoch of saved model + chkpt_dir: directory to load the model checkpoint from """ model_class = self.estimator.__class__.__name__ diff --git a/trainlib/transform.py b/trainlib/transform.py index a84d009..f5737dd 100644 --- a/trainlib/transform.py +++ b/trainlib/transform.py @@ -8,4 +8,14 @@ class Transform[I]: """ def __call__(self, item: I) -> I: + """ + Apply transform to item. 
+ + Parameters: + item: item object to transform + + Returns: + transformed item (same type ``I`` as input) + """ + raise NotImplementedError diff --git a/trainlib/utils/custom.mplstyle b/trainlib/utils/custom.mplstyle new file mode 100644 index 0000000..7b9558c --- /dev/null +++ b/trainlib/utils/custom.mplstyle @@ -0,0 +1,46 @@ +text.usetex : False +mathtext.default : regular + +font.family : sans-serif +font.sans-serif : DejaVu Sans +font.serif : DejaVu Serif +font.cursive : DejaVu Sans +mathtext.fontset : dejavuserif +font.size : 9 +figure.titlesize : 9 +legend.fontsize : 9 +axes.titlesize : 9 +axes.labelsize : 9 +xtick.labelsize : 9 +ytick.labelsize : 9 + +#axes.prop_cycle : cycler('color', ['4f7dd5', 'af7031', '55905e', 'd84739', '888348', 'b75e8b', '2f8f99', '9862cb']) +axes.prop_cycle : cycler('color', ['5e8de4', 'c38141', '67a771', 'e15344', '9e9858', '41a6b0', 'a46fd7', 'c86d9a']) + +image.interpolation : nearest +image.resample : False +image.composite_image : True + +axes.spines.left : True +axes.spines.bottom : True +axes.spines.top : False +axes.spines.right : False + +axes.linewidth : 1 +xtick.major.width : 1 +xtick.minor.width : 1 +ytick.major.width : 1 +ytick.minor.width : 1 + +lines.linewidth : 1 +lines.markersize : 1 + +savefig.dpi : 300 +savefig.format : svg +savefig.bbox : tight +savefig.pad_inches : 0.1 + +svg.image_inline : True +svg.fonttype : none + +legend.frameon : False diff --git a/trainlib/utils/plot.py b/trainlib/utils/plot.py new file mode 100644 index 0000000..70d9ffc --- /dev/null +++ b/trainlib/utils/plot.py @@ -0,0 +1,38 @@ +from pathlib import Path + +import matplotlib as mpl +import matplotlib.pyplot as plt + +FILE = Path(__file__).parent.absolute() + + +class use_style: + def __init__( + self, + style: list[str] | None = None, + **kwargs: str, + ) -> None: + super().__init__() + + if style is None: + style = [str(Path(FILE, "custom.mplstyle"))] + self.style = style + [kwargs] + self.previous_style = {} + + def __enter__(self) 
-> None: + self.previous_style = mpl.rcParams.copy() + if self.style is not None: + plt.style.use(self.style) + + def __exit__(self, *args: str, **kwargs: str) -> None: + mpl.rcParams.update(self.previous_style) + + +def set_style( + style: list[str] | None = None, + **kwargs: str, +) -> None: + if style is None: + style = [str(Path(FILE, "custom.mplstyle"))] + + plt.style.use(style + [kwargs]) diff --git a/trainlib/utils/session.py b/trainlib/utils/session.py new file mode 100644 index 0000000..a46d072 --- /dev/null +++ b/trainlib/utils/session.py @@ -0,0 +1,21 @@ +import random + +import numpy as np +import torch +from torch import Tensor + + +def seed_all_backends(seed: int | Tensor | None = None) -> None: + """Sets all python, numpy and pytorch seeds.""" + + if seed is None: + seed = int(torch.randint(1000000, size=(1,))) + else: + seed = int(seed) + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False diff --git a/uv.lock b/uv.lock index ac76e4a..8263fed 100644 --- a/uv.lock +++ b/uv.lock @@ -248,9 +248,13 @@ dependencies = [ { name = "cuda-pathfinder" }, ] wheels = [ + { url = "https://files.pythonhosted.org/packages/05/8b/b4b2d1c7775fa403b64333e720cfcfccef8dcb9cdeb99947061ca5a77628/cuda_bindings-12.9.4-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cf8bfaedc238f3b115d957d1fd6562b7e8435ba57f6d0e2f87d0e7149ccb2da5", size = 11570071, upload-time = "2025-10-21T14:51:47.472Z" }, { url = "https://files.pythonhosted.org/packages/63/56/e465c31dc9111be3441a9ba7df1941fe98f4aa6e71e8788a3fb4534ce24d/cuda_bindings-12.9.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:32bdc5a76906be4c61eb98f546a6786c5773a881f3b166486449b5d141e4a39f", size = 11906628, upload-time = "2025-10-21T14:51:49.905Z" }, + { url = 
"https://files.pythonhosted.org/packages/ec/07/6aff13bc1e977e35aaa6b22f52b172e2890c608c6db22438cf7ed2bf43a6/cuda_bindings-12.9.4-cp313-cp313t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3adf4958dcf68ae7801a59b73fb00a8b37f8d0595060d66ceae111b1002de38d", size = 11566797, upload-time = "2025-10-21T14:51:54.581Z" }, { url = "https://files.pythonhosted.org/packages/a3/84/1e6be415e37478070aeeee5884c2022713c1ecc735e6d82d744de0252eee/cuda_bindings-12.9.4-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:56e0043c457a99ac473ddc926fe0dc4046694d99caef633e92601ab52cbe17eb", size = 11925991, upload-time = "2025-10-21T14:51:56.535Z" }, + { url = "https://files.pythonhosted.org/packages/1e/b5/96a6696e20c4ffd2b327f54c7d0fde2259bdb998d045c25d5dedbbe30290/cuda_bindings-12.9.4-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1f53a7f453d4b2643d8663d036bafe29b5ba89eb904c133180f295df6dc151e5", size = 11624530, upload-time = "2025-10-21T14:52:01.539Z" }, { url = "https://files.pythonhosted.org/packages/d1/af/6dfd8f2ed90b1d4719bc053ff8940e494640fe4212dc3dd72f383e4992da/cuda_bindings-12.9.4-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8b72ee72a9cc1b531db31eebaaee5c69a8ec3500e32c6933f2d3b15297b53686", size = 11922703, upload-time = "2025-10-21T14:52:03.585Z" }, + { url = "https://files.pythonhosted.org/packages/39/73/d2fc40c043bac699c3880bf88d3cebe9d88410cd043795382826c93a89f0/cuda_bindings-12.9.4-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:20f2699d61d724de3eb3f3369d57e2b245f93085cab44fd37c3bea036cea1a6f", size = 11565056, upload-time = "2025-10-21T14:52:08.338Z" }, { url = "https://files.pythonhosted.org/packages/6c/19/90ac264acc00f6df8a49378eedec9fd2db3061bf9263bf9f39fd3d8377c3/cuda_bindings-12.9.4-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d80bffc357df9988dca279734bc9674c3934a654cab10cadeed27ce17d8635ee", 
size = 11924658, upload-time = "2025-10-21T14:52:10.411Z" }, ] @@ -861,6 +865,7 @@ name = "nvidia-cublas-cu12" version = "12.8.4.1" source = { registry = "https://pypi.org/simple" } wheels = [ + { url = "https://files.pythonhosted.org/packages/29/99/db44d685f0e257ff0e213ade1964fc459b4a690a73293220e98feb3307cf/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:b86f6dd8935884615a0683b663891d43781b819ac4f2ba2b0c9604676af346d0", size = 590537124, upload-time = "2025-03-07T01:43:53.556Z" }, { url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921, upload-time = "2025-03-07T01:44:31.254Z" }, ] @@ -869,6 +874,7 @@ name = "nvidia-cuda-cupti-cu12" version = "12.8.90" source = { registry = "https://pypi.org/simple" } wheels = [ + { url = "https://files.pythonhosted.org/packages/d5/1f/b3bd73445e5cb342727fd24fe1f7b748f690b460acadc27ea22f904502c8/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:4412396548808ddfed3f17a467b104ba7751e6b58678a4b840675c56d21cf7ed", size = 9533318, upload-time = "2025-03-07T01:40:10.421Z" }, { url = "https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621, upload-time = "2025-03-07T01:40:21.213Z" }, ] @@ -878,6 +884,7 @@ version = "12.8.93" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = 
"sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029, upload-time = "2025-03-07T01:42:13.562Z" }, + { url = "https://files.pythonhosted.org/packages/eb/d1/e50d0acaab360482034b84b6e27ee83c6738f7d32182b987f9c7a4e32962/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fc1fec1e1637854b4c0a65fb9a8346b51dd9ee69e61ebaccc82058441f15bce8", size = 43106076, upload-time = "2025-03-07T01:41:59.817Z" }, ] [[package]] @@ -885,6 +892,7 @@ name = "nvidia-cuda-runtime-cu12" version = "12.8.90" source = { registry = "https://pypi.org/simple" } wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/75/f865a3b236e4647605ea34cc450900854ba123834a5f1598e160b9530c3a/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:52bf7bbee900262ffefe5e9d5a2a69a30d97e2bc5bb6cc866688caa976966e3d", size = 965265, upload-time = "2025-03-07T01:39:43.533Z" }, { url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765, upload-time = "2025-03-07T01:40:01.615Z" }, ] @@ -896,6 +904,7 @@ dependencies = [ { name = "nvidia-cublas-cu12" }, ] wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/41/e79269ce215c857c935fd86bcfe91a451a584dfc27f1e068f568b9ad1ab7/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8", size = 705026878, upload-time = "2025-06-06T21:52:51.348Z" }, { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = 
"sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, ] @@ -907,6 +916,7 @@ dependencies = [ { name = "nvidia-nvjitlink-cu12" }, ] wheels = [ + { url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = "2025-03-07T01:44:56.873Z" }, { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, ] @@ -916,6 +926,7 @@ version = "1.13.1.3" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834, upload-time = "2025-03-07T01:45:50.723Z" }, + { url = "https://files.pythonhosted.org/packages/1e/f5/5607710447a6fe9fd9b3283956fceeee8a06cda1d2f56ce31371f595db2a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:4beb6d4cce47c1a0f1013d72e02b0994730359e17801d395bdcbf20cfb3bb00a", size = 1120705, upload-time = "2025-03-07T01:45:41.434Z" }, ] [[package]] @@ -923,6 +934,7 @@ name = "nvidia-curand-cu12" version = "10.3.9.90" source = { registry = "https://pypi.org/simple" } wheels = [ + { url = 
"https://files.pythonhosted.org/packages/45/5e/92aa15eca622a388b80fbf8375d4760738df6285b1e92c43d37390a33a9a/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:dfab99248034673b779bc6decafdc3404a8a6f502462201f2f31f11354204acd", size = 63625754, upload-time = "2025-03-07T01:46:10.735Z" }, { url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976, upload-time = "2025-03-07T01:46:23.323Z" }, ] @@ -936,6 +948,7 @@ dependencies = [ { name = "nvidia-nvjitlink-cu12" }, ] wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = "2025-03-07T01:46:54.356Z" }, { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, ] @@ -947,6 +960,7 @@ dependencies = [ { name = "nvidia-nvjitlink-cu12" }, ] wheels = [ + { url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129, upload-time = "2025-03-07T01:47:40.407Z" }, { url = 
"https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, ] @@ -955,6 +969,7 @@ name = "nvidia-cusparselt-cu12" version = "0.7.1" source = { registry = "https://pypi.org/simple" } wheels = [ + { url = "https://files.pythonhosted.org/packages/73/b9/598f6ff36faaece4b3c50d26f50e38661499ff34346f00e057760b35cc9d/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8878dce784d0fac90131b6817b607e803c36e629ba34dc5b433471382196b6a5", size = 283835557, upload-time = "2025-02-26T00:16:54.265Z" }, { url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691, upload-time = "2025-02-26T00:15:44.104Z" }, ] @@ -963,6 +978,7 @@ name = "nvidia-nccl-cu12" version = "2.27.5" source = { registry = "https://pypi.org/simple" } wheels = [ + { url = "https://files.pythonhosted.org/packages/bb/1c/857979db0ef194ca5e21478a0612bcdbbe59458d7694361882279947b349/nvidia_nccl_cu12-2.27.5-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:31432ad4d1fb1004eb0c56203dc9bc2178a1ba69d1d9e02d64a6938ab5e40e7a", size = 322400625, upload-time = "2025-06-26T04:11:04.496Z" }, { url = "https://files.pythonhosted.org/packages/6e/89/f7a07dc961b60645dbbf42e80f2bc85ade7feb9a491b11a1e973aa00071f/nvidia_nccl_cu12-2.27.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ad730cf15cb5d25fe849c6e6ca9eb5b76db16a80f13f425ac68d8e2e55624457", size = 322348229, upload-time = "2025-06-26T04:11:28.385Z" }, ] @@ -972,6 +988,7 @@ version = "12.8.93" source = { 
registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836, upload-time = "2025-03-07T01:49:55.661Z" }, + { url = "https://files.pythonhosted.org/packages/2a/a2/8cee5da30d13430e87bf99bb33455d2724d0a4a9cb5d7926d80ccb96d008/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:adccd7161ace7261e01bb91e44e88da350895c270d23f744f0820c818b7229e7", size = 38386204, upload-time = "2025-03-07T01:49:43.612Z" }, ] [[package]] @@ -979,6 +996,7 @@ name = "nvidia-nvshmem-cu12" version = "3.4.5" source = { registry = "https://pypi.org/simple" } wheels = [ + { url = "https://files.pythonhosted.org/packages/1d/6a/03aa43cc9bd3ad91553a88b5f6fb25ed6a3752ae86ce2180221962bc2aa5/nvidia_nvshmem_cu12-3.4.5-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0b48363fc6964dede448029434c6abed6c5e37f823cb43c3bcde7ecfc0457e15", size = 138936938, upload-time = "2025-09-06T00:32:05.589Z" }, { url = "https://files.pythonhosted.org/packages/b5/09/6ea3ea725f82e1e76684f0708bbedd871fc96da89945adeba65c3835a64c/nvidia_nvshmem_cu12-3.4.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:042f2500f24c021db8a06c5eec2539027d57460e1c1a762055a6554f72c369bd", size = 139103095, upload-time = "2025-09-06T00:32:31.266Z" }, ] @@ -987,6 +1005,7 @@ name = "nvidia-nvtx-cu12" version = "12.8.90" source = { registry = "https://pypi.org/simple" } wheels = [ + { url = "https://files.pythonhosted.org/packages/10/c0/1b303feea90d296f6176f32a2a70b5ef230f9bdeb3a72bddb0dc922dc137/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d7ad891da111ebafbf7e015d34879f7112832fc239ff0d7d776b6cb685274615", 
size = 91161, upload-time = "2025-03-07T01:42:23.922Z" }, { url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" }, ] @@ -1393,14 +1412,14 @@ wheels = [ [[package]] name = "sphinx-autodoc-typehints" -version = "3.9.5" +version = "3.9.7" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "sphinx" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/58/ec/21bd9babcfeb9930a73011257002d5cfa5fd30667b8de6d76dbaf8275dfb/sphinx_autodoc_typehints-3.9.5.tar.gz", hash = "sha256:60e646efb7c352a0e98f34dd7fdcde4527fbbdbdf30371ff8321b6b3eb1fd37d", size = 63249, upload-time = "2026-03-02T19:58:07.974Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f4/06/da2d9e98b3f7f0df144496e62f453e0025f129bccc7a6076b8ceae6047b1/sphinx_autodoc_typehints-3.9.7.tar.gz", hash = "sha256:70f3dd4e4dd815ae30e5d3848a26dca71fb5e7fcf8f37cf8b840dc8afdf07e82", size = 68689, upload-time = "2026-03-05T18:33:40.829Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7f/cb/80c250f47a0ca5ac67d82f14811b4068a551a12b4790b085ffdb900de427/sphinx_autodoc_typehints-3.9.5-py3-none-any.whl", hash = "sha256:c94f88a90b6c61a7a6686cb77b410e46e077712838387e6cf22d69e85cfd06a5", size = 34763, upload-time = "2026-03-02T19:58:06.028Z" }, + { url = "https://files.pythonhosted.org/packages/a4/a0/e7d3365dabfa79a1b2ac7d3122b5b22b401a9c4d5e4eadc5e13b88c63a2c/sphinx_autodoc_typehints-3.9.7-py3-none-any.whl", hash = "sha256:dd73f6a32adef0d8208f6f7d99254e1880259c77db7b4a91648345d45202d48e", size = 36691, upload-time = "2026-03-05T18:33:38.983Z" }, ] [[package]] @@ -1542,52 +1561,47 @@ wheels = [ [[package]] name = "torch" -version = "2.10.0" -source = { registry = "https://pypi.org/simple" } +version 
= "2.10.0+cu128" +source = { registry = "https://download.pytorch.org/whl/cu128" } dependencies = [ - { name = "cuda-bindings", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "cuda-bindings", marker = "sys_platform == 'linux'" }, { name = "filelock" }, { name = "fsspec" }, { name = "jinja2" }, { name = "networkx" }, - { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvshmem-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 
'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cufile-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cusparselt-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvshmem-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "sys_platform == 'linux'" }, { name = "setuptools" }, { name = "sympy" }, - { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "triton", marker = "sys_platform == 'linux'" }, { name = "typing-extensions" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/ec/23/2c9fe0c9c27f7f6cb865abcea8a4568f29f00acaeadfc6a37f6801f84cb4/torch-2.10.0-2-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:e521c9f030a3774ed770a9c011751fb47c4d12029a3d6522116e48431f2ff89e", size = 79498254, upload-time = "2026-02-10T21:44:44.095Z" }, - { url = "https://files.pythonhosted.org/packages/c9/6f/f2e91e34e3fcba2e3fc8d8f74e7d6c22e74e480bbd1db7bc8900fdf3e95c/torch-2.10.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:5c4d217b14741e40776dd7074d9006fd28b8a97ef5654db959d8635b2fe5f29b", size = 146004247, upload-time = "2026-01-21T16:24:29.335Z" }, - { url = 
"https://files.pythonhosted.org/packages/98/fb/5160261aeb5e1ee12ee95fe599d0541f7c976c3701d607d8fc29e623229f/torch-2.10.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:6b71486353fce0f9714ca0c9ef1c850a2ae766b409808acd58e9678a3edb7738", size = 915716445, upload-time = "2026-01-21T16:22:45.353Z" }, - { url = "https://files.pythonhosted.org/packages/6a/16/502fb1b41e6d868e8deb5b0e3ae926bbb36dab8ceb0d1b769b266ad7b0c3/torch-2.10.0-cp313-cp313-win_amd64.whl", hash = "sha256:c2ee399c644dc92ef7bc0d4f7e74b5360c37cdbe7c5ba11318dda49ffac2bc57", size = 113757050, upload-time = "2026-01-21T16:24:19.204Z" }, - { url = "https://files.pythonhosted.org/packages/1a/0b/39929b148f4824bc3ad6f9f72a29d4ad865bcf7ebfc2fa67584773e083d2/torch-2.10.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:3202429f58309b9fa96a614885eace4b7995729f44beb54d3e4a47773649d382", size = 79851305, upload-time = "2026-01-21T16:24:09.209Z" }, - { url = "https://files.pythonhosted.org/packages/d8/14/21fbce63bc452381ba5f74a2c0a959fdf5ad5803ccc0c654e752e0dbe91a/torch-2.10.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:aae1b29cd68e50a9397f5ee897b9c24742e9e306f88a807a27d617f07adb3bd8", size = 146005472, upload-time = "2026-01-21T16:22:29.022Z" }, - { url = "https://files.pythonhosted.org/packages/54/fd/b207d1c525cb570ef47f3e9f836b154685011fce11a2f444ba8a4084d042/torch-2.10.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:6021db85958db2f07ec94e1bc77212721ba4920c12a18dc552d2ae36a3eb163f", size = 915612644, upload-time = "2026-01-21T16:21:47.019Z" }, - { url = "https://files.pythonhosted.org/packages/36/53/0197f868c75f1050b199fe58f9bf3bf3aecac9b4e85cc9c964383d745403/torch-2.10.0-cp313-cp313t-win_amd64.whl", hash = "sha256:ff43db38af76fda183156153983c9a096fc4c78d0cd1e07b14a2314c7f01c2c8", size = 113997015, upload-time = "2026-01-21T16:23:00.767Z" }, - { url = 
"https://files.pythonhosted.org/packages/0e/13/e76b4d9c160e89fff48bf16b449ea324bda84745d2ab30294c37c2434c0d/torch-2.10.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:cdf2a523d699b70d613243211ecaac14fe9c5df8a0b0a9c02add60fb2a413e0f", size = 79498248, upload-time = "2026-01-21T16:23:09.315Z" }, - { url = "https://files.pythonhosted.org/packages/4f/93/716b5ac0155f1be70ed81bacc21269c3ece8dba0c249b9994094110bfc51/torch-2.10.0-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:bf0d9ff448b0218e0433aeb198805192346c4fd659c852370d5cc245f602a06a", size = 79464992, upload-time = "2026-01-21T16:23:05.162Z" }, - { url = "https://files.pythonhosted.org/packages/69/2b/51e663ff190c9d16d4a8271203b71bc73a16aa7619b9f271a69b9d4a936b/torch-2.10.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:233aed0659a2503b831d8a67e9da66a62c996204c0bba4f4c442ccc0c68a3f60", size = 146018567, upload-time = "2026-01-21T16:22:23.393Z" }, - { url = "https://files.pythonhosted.org/packages/5e/cd/4b95ef7f293b927c283db0b136c42be91c8ec6845c44de0238c8c23bdc80/torch-2.10.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:682497e16bdfa6efeec8cde66531bc8d1fbbbb4d8788ec6173c089ed3cc2bfe5", size = 915721646, upload-time = "2026-01-21T16:21:16.983Z" }, - { url = "https://files.pythonhosted.org/packages/56/97/078a007208f8056d88ae43198833469e61a0a355abc0b070edd2c085eb9a/torch-2.10.0-cp314-cp314-win_amd64.whl", hash = "sha256:6528f13d2a8593a1a412ea07a99812495bec07e9224c28b2a25c0a30c7da025c", size = 113752373, upload-time = "2026-01-21T16:22:13.471Z" }, - { url = "https://files.pythonhosted.org/packages/d8/94/71994e7d0d5238393df9732fdab607e37e2b56d26a746cb59fdb415f8966/torch-2.10.0-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:f5ab4ba32383061be0fb74bda772d470140a12c1c3b58a0cfbf3dae94d164c28", size = 79850324, upload-time = "2026-01-21T16:22:09.494Z" }, - { url = 
"https://files.pythonhosted.org/packages/e2/65/1a05346b418ea8ccd10360eef4b3e0ce688fba544e76edec26913a8d0ee0/torch-2.10.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:716b01a176c2a5659c98f6b01bf868244abdd896526f1c692712ab36dbaf9b63", size = 146006482, upload-time = "2026-01-21T16:22:18.42Z" }, - { url = "https://files.pythonhosted.org/packages/1d/b9/5f6f9d9e859fc3235f60578fa64f52c9c6e9b4327f0fe0defb6de5c0de31/torch-2.10.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:d8f5912ba938233f86361e891789595ff35ca4b4e2ac8fe3670895e5976731d6", size = 915613050, upload-time = "2026-01-21T16:20:49.035Z" }, - { url = "https://files.pythonhosted.org/packages/66/4d/35352043ee0eaffdeff154fad67cd4a31dbed7ff8e3be1cc4549717d6d51/torch-2.10.0-cp314-cp314t-win_amd64.whl", hash = "sha256:71283a373f0ee2c89e0f0d5f446039bdabe8dbc3c9ccf35f0f784908b0acd185", size = 113995816, upload-time = "2026-01-21T16:22:05.312Z" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:bdbcc703382f948e951c063448c9406bf38ce66c41dd698d9e2733fcf96c037a" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:7b4bd23ed63de97456fcc81c26fea9f02ee02ce1112111c4dac0d8cfe574b23e" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313-win_amd64.whl", hash = "sha256:4d1b0b49c54223c7c04050b49eac141d77b6edbc34aea1dfc74a6fdb661baa8c" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:f1f8b840c64b645a4bc61a393db48effb9c92b2dc26c8373873911f0750d1ea7" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:23f58258012bcf1c349cb22af387e33aadca7f83ea617b080e774eb41e4fe8ff" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp313-cp313t-win_amd64.whl", hash = 
"sha256:01b216e097b17a5277cfb47c383cdcacf06abeadcb0daca0c76b59e72854c3b6" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:c42377bc2607e3e1c60da71b792fb507c3938c87fd6edab8b21c59c91473c36d" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:37d71feea068776855686a1512058df3f19f6f040a151f055aa746601678744f" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314-win_amd64.whl", hash = "sha256:c57017ca29e62271e362fdeee7d20070e254755a5148b30b553d8a10fc83c7ef" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:777461f50b2daf77e4bdd8e2ad34bdfc5a993bf1bdf2ab9ef39f5edfe4e9c12b" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:7bcba6a7c5f0987a13298b1ca843155dcceceac758fa3c7ccd5c7af4059a1080" }, + { url = "https://download.pytorch.org/whl/cu128/torch-2.10.0%2Bcu128-cp314-cp314t-win_amd64.whl", hash = "sha256:70d89143c956389d4806cb4e5fe0b1129fe0db280e1073288d17fa76c101cba4" }, ] [[package]] @@ -1623,7 +1637,7 @@ wheels = [ [[package]] name = "trainlib" -version = "0.1.0" +version = "0.1.1" source = { editable = "." 
} dependencies = [ { name = "colorama" }, @@ -1662,7 +1676,7 @@ requires-dist = [ { name = "sphinx-autodoc-typehints", marker = "extra == 'doc'" }, { name = "sphinx-togglebutton", marker = "extra == 'doc'" }, { name = "tensorboard", specifier = ">=2.20.0" }, - { name = "torch", specifier = ">=2.5.1" }, + { name = "torch", index = "https://download.pytorch.org/whl/cu128" }, { name = "tqdm", specifier = ">=4.67.1" }, ] provides-extras = ["dev", "doc", "test"] @@ -1681,9 +1695,13 @@ name = "triton" version = "3.6.0" source = { registry = "https://pypi.org/simple" } wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/12/34d71b350e89a204c2c7777a9bba0dcf2f19a5bfdd70b57c4dbc5ffd7154/triton-3.6.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:448e02fe6dc898e9e5aa89cf0ee5c371e99df5aa5e8ad976a80b93334f3494fd", size = 176133521, upload-time = "2026-01-20T16:16:13.321Z" }, { url = "https://files.pythonhosted.org/packages/f9/0b/37d991d8c130ce81a8728ae3c25b6e60935838e9be1b58791f5997b24a54/triton-3.6.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10c7f76c6e72d2ef08df639e3d0d30729112f47a56b0c81672edc05ee5116ac9", size = 188289450, upload-time = "2026-01-20T16:00:49.136Z" }, + { url = "https://files.pythonhosted.org/packages/ce/4e/41b0c8033b503fd3cfcd12392cdd256945026a91ff02452bef40ec34bee7/triton-3.6.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1722e172d34e32abc3eb7711d0025bb69d7959ebea84e3b7f7a341cd7ed694d6", size = 176276087, upload-time = "2026-01-20T16:16:18.989Z" }, { url = "https://files.pythonhosted.org/packages/35/f8/9c66bfc55361ec6d0e4040a0337fb5924ceb23de4648b8a81ae9d33b2b38/triton-3.6.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d002e07d7180fd65e622134fbd980c9a3d4211fb85224b56a0a0efbd422ab72f", size = 188400296, upload-time = "2026-01-20T16:00:56.042Z" }, + { url = 
"https://files.pythonhosted.org/packages/49/55/5ecf0dcaa0f2fbbd4420f7ef227ee3cb172e91e5fede9d0ecaddc43363b4/triton-3.6.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ef5523241e7d1abca00f1d240949eebdd7c673b005edbbce0aca95b8191f1d43", size = 176138577, upload-time = "2026-01-20T16:16:25.426Z" }, { url = "https://files.pythonhosted.org/packages/df/3d/9e7eee57b37c80cec63322c0231bb6da3cfe535a91d7a4d64896fcb89357/triton-3.6.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a17a5d5985f0ac494ed8a8e54568f092f7057ef60e1b0fa09d3fd1512064e803", size = 188273063, upload-time = "2026-01-20T16:01:07.278Z" }, + { url = "https://files.pythonhosted.org/packages/48/db/56ee649cab5eaff4757541325aca81f52d02d4a7cd3506776cad2451e060/triton-3.6.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0b3a97e8ed304dfa9bd23bb41ca04cdf6b2e617d5e782a8653d616037a5d537d", size = 176274804, upload-time = "2026-01-20T16:16:31.528Z" }, { url = "https://files.pythonhosted.org/packages/f6/56/6113c23ff46c00aae423333eb58b3e60bdfe9179d542781955a5e1514cb3/triton-3.6.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:46bd1c1af4b6704e554cad2eeb3b0a6513a980d470ccfa63189737340c7746a7", size = 188397994, upload-time = "2026-01-20T16:01:14.236Z" }, ]