diff --git a/trainlib/datasets/disk.py b/trainlib/datasets/disk.py index e7a339e..de4d326 100644 --- a/trainlib/datasets/disk.py +++ b/trainlib/datasets/disk.py @@ -4,8 +4,8 @@ from typing import Any, NamedTuple from pathlib import Path from zipfile import ZipFile -from mema.dataset import HomogenousDataset -from mema.domains.disk import DiskDomain +from trainlib.dataset import HomogenousDataset +from trainlib.domains.disk import DiskDomain class DiskDataset[T: NamedTuple](HomogenousDataset[Path, bytes, T]): diff --git a/trainlib/domains/disk.py b/trainlib/domains/disk.py index 384b721..4c52838 100644 --- a/trainlib/domains/disk.py +++ b/trainlib/domains/disk.py @@ -1,7 +1,7 @@ from pathlib import Path from collections.abc import Iterator -from mema.domain import Domain +from trainlib.domain import Domain class DiskDomain(Domain[Path, bytes]): diff --git a/trainlib/estimators/rnn.py b/trainlib/estimators/rnn.py index 0c0b4fa..0ca7ce6 100644 --- a/trainlib/estimators/rnn.py +++ b/trainlib/estimators/rnn.py @@ -8,10 +8,10 @@ from torch import nn, Tensor from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter -from mema.estimator import Estimator, EstimatorKwargs -from mema.util.type import OptimizerKwargs -from mema.util.module import get_grad_norm -from mema.estimators.tdnn import TDNNLayer +from trainlib.estimator import Estimator, EstimatorKwargs +from trainlib.utils.type import OptimizerKwargs +from trainlib.utils.module import get_grad_norm +from trainlib.estimators.tdnn import TDNNLayer logger: logging.Logger = logging.getLogger(__name__) diff --git a/trainlib/trainer.py b/trainlib/trainer.py index 3450b2f..c96f174 100644 --- a/trainlib/trainer.py +++ b/trainlib/trainer.py @@ -18,12 +18,12 @@ from torch.utils.tensorboard import SummaryWriter from trainlib.dataset import BatchedDataset from trainlib.estimator import Estimator, EstimatorKwargs from trainlib.transform import Transform -from trainlib.util.type import ( +from trainlib.utils.type import ( SplitKwargs, LoaderKwargs, BalanceKwargs, ) -from trainlib.util.module import ModelWrapper +from trainlib.utils.module import ModelWrapper logger: logging.Logger = logging.getLogger(__name__) diff --git a/trainlib/utils/job.py b/trainlib/utils/job.py index 7f4af9e..dad0979 100644 --- a/trainlib/utils/job.py +++ b/trainlib/utils/job.py @@ -5,7 +5,7 @@ from concurrent.futures import Future, as_completed from tqdm import tqdm from colorama import Fore, Style -from mema.util.text import color_text +from trainlib.utils.text import color_text logger: logging.Logger = logging.getLogger(__name__) @@ -15,6 +15,15 @@ def process_futures( desc: str | None = None, unit: str | None = None, ) -> None: + """ + Wait on futures results in a blocking loop, showing progress with TQDM. + + Parameters: + futures: list of futures to wait on + desc: description for the TQDM bar + unit: unit to display for loops in the TQDM bar + """ + if desc is None: desc = "Awaiting futures" diff --git a/trainlib/utils/type.py b/trainlib/utils/type.py index 64d81b1..c0489cb 100644 --- a/trainlib/utils/type.py +++ b/trainlib/utils/type.py @@ -4,7 +4,7 @@ from collections.abc import Callable, Iterable from torch import Tensor from torch.utils.data.sampler import Sampler -from mema.dataset import BatchedDataset +from trainlib.dataset import BatchedDataset class LoaderKwargs(TypedDict, total=False):