# trainlib/trainlib/utils/type.py
# (70 lines, 1.8 KiB, Python)
from typing import Any, TypedDict
from collections.abc import Callable, Iterable, Sequence
import numpy as np
from torch import Tensor
from torch.utils.data.sampler import Sampler
from trainlib.dataset import BatchedDataset
# Matplotlib returns the axes from ``plt.subplots`` as a plain object-dtype
# ndarray rather than a typed container, so we alias that 2-D shape here for
# use in annotations.
type AxesArray = np.ndarray[tuple[int, int], np.dtype[np.object_]]
class LoaderKwargs(TypedDict, total=False):
    """Optional keyword arguments for ``torch.utils.data.DataLoader``.

    All keys are optional (``total=False``); callers populate only the
    settings they want to override and splat the dict into the loader
    constructor. Field names and types mirror the DataLoader signature —
    the consuming call site is not visible in this file, so confirm there.
    """

    batch_size: int | None
    shuffle: bool
    # mutually exclusive with shuffle/batch_size in DataLoader itself
    sampler: Sampler | Iterable | None
    batch_sampler: Sampler[list] | Iterable[list] | None
    num_workers: int
    collate_fn: Callable[[list], Any]
    pin_memory: bool
    drop_last: bool
    timeout: float
    worker_init_fn: Callable[[int], None]
    multiprocessing_context: object
    generator: object
    prefetch_factor: int
    persistent_workers: bool
    pin_memory_device: str
    in_order: bool
class SplitKwargs(TypedDict, total=False):
    """Optional keyword arguments for a dataset-splitting routine.

    NOTE(review): the consumer is not visible in this file — presumably a
    split method on/around ``BatchedDataset``; confirm against the call site.
    All keys are optional (``total=False``).
    """

    dataset: BatchedDataset | None
    # attribute name(s) to stratify by; a ``None`` entry in the list
    # presumably means "no stratification at that level" — TODO confirm
    by_attr: str | list[str | None] | None
    shuffle_strata: bool
class BalanceKwargs(TypedDict, total=False):
    """Optional keyword arguments for a dataset-balancing routine.

    NOTE(review): the consumer is not visible in this file; field names
    suggest per-split size bounds and stratified shuffling — confirm at the
    call site. All keys are optional (``total=False``).
    """

    # attribute name(s) to balance by; mirrors ``SplitKwargs.by_attr``
    by_attr: str | list[str | None] | None
    # presumably one bound per split, aligned by index — TODO confirm
    split_min_sizes: list[int] | None
    split_max_sizes: list[int] | None
    shuffle_strata: bool
class OptimizerKwargs(TypedDict, total=False):
    """Optional keyword arguments for an Adam-family torch optimizer.

    Field names and types mirror the ``torch.optim.Adam``/``AdamW``
    constructor signature; the consuming optimizer is not visible in this
    file — confirm at the call site. All keys are optional (``total=False``).
    """

    lr: float | Tensor
    betas: tuple[float | Tensor, float | Tensor]
    eps: float
    weight_decay: float
    amsgrad: bool
    maximize: bool
    foreach: bool | None
    capturable: bool
    differentiable: bool
    fused: bool | None
class SubplotsKwargs(TypedDict, total=False):
sharex: bool | str
sharey: bool | str
squeeze: bool
width_ratios: Sequence[float]
height_ratios: Sequence[float]
subplot_kw: dict[str, ...]
gridspec_kw: dict[str, ...]
figsize: tuple[float, float]
dpi: float
layout: str
constrained_layout: bool