update package internal references

This commit is contained in:
2026-03-03 18:23:11 -08:00
parent 337175d428
commit c473e48b5b
6 changed files with 20 additions and 11 deletions

View File

@@ -4,8 +4,8 @@ from typing import Any, NamedTuple
from pathlib import Path
from zipfile import ZipFile
from mema.dataset import HomogenousDataset
from mema.domains.disk import DiskDomain
from trainlib.dataset import HomogenousDataset
from trainlib.domains.disk import DiskDomain
class DiskDataset[T: NamedTuple](HomogenousDataset[Path, bytes, T]):

View File

@@ -1,7 +1,7 @@
from pathlib import Path
from collections.abc import Iterator
from mema.domain import Domain
from trainlib.domain import Domain
class DiskDomain(Domain[Path, bytes]):

View File

@@ -8,10 +8,10 @@ from torch import nn, Tensor
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from mema.estimator import Estimator, EstimatorKwargs
from mema.util.type import OptimizerKwargs
from mema.util.module import get_grad_norm
from mema.estimators.tdnn import TDNNLayer
from trainlib.estimator import Estimator, EstimatorKwargs
from trainlib.utils.type import OptimizerKwargs
from trainlib.utils.module import get_grad_norm
from trainlib.estimators.tdnn import TDNNLayer
logger: logging.Logger = logging.getLogger(__name__)

View File

@@ -18,12 +18,12 @@ from torch.utils.tensorboard import SummaryWriter
from trainlib.dataset import BatchedDataset
from trainlib.estimator import Estimator, EstimatorKwargs
from trainlib.transform import Transform
from trainlib.util.type import (
from trainlib.utils.type import (
SplitKwargs,
LoaderKwargs,
BalanceKwargs,
)
from trainlib.util.module import ModelWrapper
from trainlib.utils.module import ModelWrapper
logger: logging.Logger = logging.getLogger(__name__)

View File

@@ -5,7 +5,7 @@ from concurrent.futures import Future, as_completed
from tqdm import tqdm
from colorama import Fore, Style
from mema.util.text import color_text
from trainlib.utils.text import color_text
logger: logging.Logger = logging.getLogger(__name__)
@@ -15,6 +15,15 @@ def process_futures(
desc: str | None = None,
unit: str | None = None,
) -> None:
"""
Wait on futures results in a blocking loop, showing progress with TQDM.
Parameters:
futures: list of futures to wait on
desc: description for the TQDM bar
unit: unit to display for loops in the TQDM bar
"""
if desc is None:
desc = "Awaiting futures"

View File

@@ -4,7 +4,7 @@ from collections.abc import Callable, Iterable
from torch import Tensor
from torch.utils.data.sampler import Sampler
from mema.dataset import BatchedDataset
from trainlib.dataset import BatchedDataset
class LoaderKwargs(TypedDict, total=False):