from typing import Any, TypedDict from collections.abc import Callable, Iterable import numpy as np from torch import Tensor from torch.utils.data.sampler import Sampler from trainlib.dataset import BatchedDataset # need b/c matplotlib axes are insanely stupid type AxesArray = np.ndarray[tuple[int, int], np.dtype[np.object_]] class LoaderKwargs(TypedDict, total=False): batch_size: int shuffle: bool sampler: Sampler | Iterable | None batch_sampler: Sampler[list] | Iterable[list] | None num_workers: int collate_fn: Callable[[list], Any] pin_memory: bool drop_last: bool timeout: float worker_init_fn: Callable[[int], None] multiprocessing_context: object generator: object prefetch_factor: int persistent_workers: bool pin_memory_device: str in_order: bool class SplitKwargs(TypedDict, total=False): dataset: BatchedDataset | None by_attr: str | list[str | None] | None shuffle_strata: bool class BalanceKwargs(TypedDict, total=False): by_attr: str | list[str | None] | None split_min_sizes: list[int] | None split_max_sizes: list[int] | None shuffle_strata: bool class OptimizerKwargs(TypedDict, total=False): lr: float | Tensor betas: tuple[float | Tensor, float | Tensor] eps: float weight_decay: float amsgrad: bool maximize: bool foreach: bool | None capturable: bool differentiable: bool fused: bool | None class SubplotsKwargs(TypedDict, total=False): sharex: bool | str sharey: bool | str squeeze: bool width_ratios: list[float] height_ratios: list[float] subplot_kw: dict gridspec_kw: dict figsize: tuple[float, float] dpi: float layout: str sharex: bool | Literal["none", "all", "row", "col"] = False, sharey: bool | Literal["none", "all", "row", "col"] = False, squeeze: bool = True, width_ratios: Sequence[float] | None = None, height_ratios: Sequence[float] | None = None, subplot_kw: dict[str, Any] | None = None, gridspec_kw: dict[str, Any] | None = None,