update package internal references
This commit is contained in:
@@ -4,8 +4,8 @@ from typing import Any, NamedTuple
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from zipfile import ZipFile
|
from zipfile import ZipFile
|
||||||
|
|
||||||
from mema.dataset import HomogenousDataset
|
from trainlib.dataset import HomogenousDataset
|
||||||
from mema.domains.disk import DiskDomain
|
from trainlib.domains.disk import DiskDomain
|
||||||
|
|
||||||
|
|
||||||
class DiskDataset[T: NamedTuple](HomogenousDataset[Path, bytes, T]):
|
class DiskDataset[T: NamedTuple](HomogenousDataset[Path, bytes, T]):
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
|
|
||||||
from mema.domain import Domain
|
from trainlib.domain import Domain
|
||||||
|
|
||||||
|
|
||||||
class DiskDomain(Domain[Path, bytes]):
|
class DiskDomain(Domain[Path, bytes]):
|
||||||
|
|||||||
@@ -8,10 +8,10 @@ from torch import nn, Tensor
|
|||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from mema.estimator import Estimator, EstimatorKwargs
|
from trainlib.estimator import Estimator, EstimatorKwargs
|
||||||
from mema.util.type import OptimizerKwargs
|
from trainlib.utils.type import OptimizerKwargs
|
||||||
from mema.util.module import get_grad_norm
|
from trainlib.utils.module import get_grad_norm
|
||||||
from mema.estimators.tdnn import TDNNLayer
|
from trainlib.estimators.tdnn import TDNNLayer
|
||||||
|
|
||||||
logger: logging.Logger = logging.getLogger(__name__)
|
logger: logging.Logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -18,12 +18,12 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
from trainlib.dataset import BatchedDataset
|
from trainlib.dataset import BatchedDataset
|
||||||
from trainlib.estimator import Estimator, EstimatorKwargs
|
from trainlib.estimator import Estimator, EstimatorKwargs
|
||||||
from trainlib.transform import Transform
|
from trainlib.transform import Transform
|
||||||
from trainlib.util.type import (
|
from trainlib.utils.type import (
|
||||||
SplitKwargs,
|
SplitKwargs,
|
||||||
LoaderKwargs,
|
LoaderKwargs,
|
||||||
BalanceKwargs,
|
BalanceKwargs,
|
||||||
)
|
)
|
||||||
from trainlib.util.module import ModelWrapper
|
from trainlib.utils.module import ModelWrapper
|
||||||
|
|
||||||
logger: logging.Logger = logging.getLogger(__name__)
|
logger: logging.Logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from concurrent.futures import Future, as_completed
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from colorama import Fore, Style
|
from colorama import Fore, Style
|
||||||
|
|
||||||
from mema.util.text import color_text
|
from trainlib.utils.text import color_text
|
||||||
|
|
||||||
logger: logging.Logger = logging.getLogger(__name__)
|
logger: logging.Logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -15,6 +15,15 @@ def process_futures(
|
|||||||
desc: str | None = None,
|
desc: str | None = None,
|
||||||
unit: str | None = None,
|
unit: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""
|
||||||
|
Wait on futures results in a blocking loop, showing progress with TQDM.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
futures: list of futures to wait on
|
||||||
|
desc: description for the TQDM bar
|
||||||
|
unit: unit to display for loops in the TQDM bar
|
||||||
|
"""
|
||||||
|
|
||||||
if desc is None:
|
if desc is None:
|
||||||
desc = "Awaiting futures"
|
desc = "Awaiting futures"
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from collections.abc import Callable, Iterable
|
|||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.utils.data.sampler import Sampler
|
from torch.utils.data.sampler import Sampler
|
||||||
|
|
||||||
from mema.dataset import BatchedDataset
|
from trainlib.dataset import BatchedDataset
|
||||||
|
|
||||||
|
|
||||||
class LoaderKwargs(TypedDict, total=False):
|
class LoaderKwargs(TypedDict, total=False):
|
||||||
|
|||||||
Reference in New Issue
Block a user