# trainlib/trainlib/utils/type.py
from typing import Any, TypedDict
from collections.abc import Callable, Iterable
import numpy as np
from torch import Tensor
from torch.utils.data.sampler import Sampler
from trainlib.dataset import BatchedDataset
# Matplotlib's ``plt.subplots`` hands back axes in an object-dtype ndarray
# rather than a typed container; this alias names that 2-D object array so
# annotations elsewhere stay readable.
type AxesArray = np.ndarray[tuple[int, int], np.dtype[np.object_]]
class LoaderKwargs(TypedDict, total=False):
batch_size: int
shuffle: bool
sampler: Sampler | Iterable | None
batch_sampler: Sampler[list] | Iterable[list] | None
num_workers: int
collate_fn: Callable[[list], Any]
pin_memory: bool
drop_last: bool
timeout: float
worker_init_fn: Callable[[int], None]
multiprocessing_context: object
generator: object
prefetch_factor: int
persistent_workers: bool
pin_memory_device: str
in_order: bool
class SplitKwargs(TypedDict, total=False):
    """Optional keyword arguments for a dataset-split routine.

    ``total=False``: callers may pass any subset of these keys.
    """

    # Dataset to split; ``None`` presumably falls back to a default at the
    # call site — confirm against the consumer.
    dataset: BatchedDataset | None
    # Attribute name(s) used to group samples before splitting.
    by_attr: str | list[str | None] | None
    shuffle_strata: bool
class BalanceKwargs(TypedDict, total=False):
by_attr: str | list[str | None] | None
split_min_sizes: list[int] | None
split_max_sizes: list[int] | None
shuffle_strata: bool
class OptimizerKwargs(TypedDict, total=False):
lr: float | Tensor
betas: tuple[float | Tensor, float | Tensor]
eps: float
weight_decay: float
amsgrad: bool
maximize: bool
foreach: bool | None
capturable: bool
differentiable: bool
fused: bool | None
class SubplotsKwargs(TypedDict, total=False):
sharex: bool | str
sharey: bool | str
squeeze: bool
width_ratios: list[float]
height_ratios: list[float]
subplot_kw: dict
gridspec_kw: dict
figsize: tuple[float, float]
dpi: float
layout: str
sharex: bool | Literal["none", "all", "row", "col"] = False,
sharey: bool | Literal["none", "all", "row", "col"] = False,
squeeze: bool = True,
width_ratios: Sequence[float] | None = None,
height_ratios: Sequence[float] | None = None,
subplot_kw: dict[str, Any] | None = None,
gridspec_kw: dict[str, Any] | None = None,