update package internal references
This commit is contained in:
@@ -4,8 +4,8 @@ from typing import Any, NamedTuple
|
||||
from pathlib import Path
|
||||
from zipfile import ZipFile
|
||||
|
||||
from mema.dataset import HomogenousDataset
|
||||
from mema.domains.disk import DiskDomain
|
||||
from trainlib.dataset import HomogenousDataset
|
||||
from trainlib.domains.disk import DiskDomain
|
||||
|
||||
|
||||
class DiskDataset[T: NamedTuple](HomogenousDataset[Path, bytes, T]):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from pathlib import Path
|
||||
from collections.abc import Iterator
|
||||
|
||||
from mema.domain import Domain
|
||||
from trainlib.domain import Domain
|
||||
|
||||
|
||||
class DiskDomain(Domain[Path, bytes]):
|
||||
|
||||
@@ -8,10 +8,10 @@ from torch import nn, Tensor
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from mema.estimator import Estimator, EstimatorKwargs
|
||||
from mema.util.type import OptimizerKwargs
|
||||
from mema.util.module import get_grad_norm
|
||||
from mema.estimators.tdnn import TDNNLayer
|
||||
from trainlib.estimator import Estimator, EstimatorKwargs
|
||||
from trainlib.utils.type import OptimizerKwargs
|
||||
from trainlib.utils.module import get_grad_norm
|
||||
from trainlib.estimators.tdnn import TDNNLayer
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -18,12 +18,12 @@ from torch.utils.tensorboard import SummaryWriter
|
||||
from trainlib.dataset import BatchedDataset
|
||||
from trainlib.estimator import Estimator, EstimatorKwargs
|
||||
from trainlib.transform import Transform
|
||||
from trainlib.util.type import (
|
||||
from trainlib.utils.type import (
|
||||
SplitKwargs,
|
||||
LoaderKwargs,
|
||||
BalanceKwargs,
|
||||
)
|
||||
from trainlib.util.module import ModelWrapper
|
||||
from trainlib.utils.module import ModelWrapper
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from concurrent.futures import Future, as_completed
|
||||
from tqdm import tqdm
|
||||
from colorama import Fore, Style
|
||||
|
||||
from mema.util.text import color_text
|
||||
from trainlib.utils.text import color_text
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -15,6 +15,15 @@ def process_futures(
|
||||
desc: str | None = None,
|
||||
unit: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Wait on futures results in a blocking loop, showing progress with TQDM.
|
||||
|
||||
Parameters:
|
||||
futures: list of futures to wait on
|
||||
desc: description for the TQDM bar
|
||||
unit: unit to display for loops in the TQDM bar
|
||||
"""
|
||||
|
||||
if desc is None:
|
||||
desc = "Awaiting futures"
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from collections.abc import Callable, Iterable
|
||||
from torch import Tensor
|
||||
from torch.utils.data.sampler import Sampler
|
||||
|
||||
from mema.dataset import BatchedDataset
|
||||
from trainlib.dataset import BatchedDataset
|
||||
|
||||
|
||||
class LoaderKwargs(TypedDict, total=False):
|
||||
|
||||
Reference in New Issue
Block a user