77 lines
2.0 KiB
Python
77 lines
2.0 KiB
Python
from collections.abc import Callable, Iterable, Sequence
from typing import Any, Literal, TypedDict

import numpy as np
from torch import Tensor
from torch.utils.data.sampler import Sampler

from trainlib.dataset import BatchedDataset
|
|
|
|
# Type alias for a 2-D NumPy array whose elements are Python objects
# (dtype ``object``).
#
# Needed because a grid of matplotlib Axes comes back as an object-dtype
# ndarray (presumably from ``plt.subplots(..., squeeze=False)`` — confirm at
# call sites). A bare ``np.ndarray`` annotation conveys neither the 2-D shape
# nor the object dtype.
type AxesArray = np.ndarray[tuple[int, int], np.dtype[np.object_]]
|
|
|
|
class LoaderKwargs(TypedDict, total=False):
|
|
batch_size: int
|
|
shuffle: bool
|
|
sampler: Sampler | Iterable | None
|
|
batch_sampler: Sampler[list] | Iterable[list] | None
|
|
num_workers: int
|
|
collate_fn: Callable[[list], Any]
|
|
pin_memory: bool
|
|
drop_last: bool
|
|
timeout: float
|
|
worker_init_fn: Callable[[int], None]
|
|
multiprocessing_context: object
|
|
generator: object
|
|
prefetch_factor: int
|
|
persistent_workers: bool
|
|
pin_memory_device: str
|
|
in_order: bool
|
|
|
|
|
|
class SplitKwargs(TypedDict, total=False):
    """Optional keyword arguments for a dataset-splitting routine.

    All keys are optional (``total=False``). NOTE(review): the routine that
    consumes these kwargs is not visible in this module; the per-field notes
    below are inferred from names and should be confirmed against it.
    """

    # Dataset to split. None is permitted (presumably "use the caller's
    # default dataset" — confirm).
    dataset: BatchedDataset | None
    # Attribute name(s) to group/stratify by: a single name, a list whose
    # entries may individually be None, or None overall (presumably "no
    # stratification" — confirm).
    by_attr: str | list[str | None] | None
    # Presumably shuffles items within each stratum — confirm.
    shuffle_strata: bool
|
|
|
|
|
|
class BalanceKwargs(TypedDict, total=False):
    """Optional keyword arguments for a dataset-balancing routine.

    All keys are optional (``total=False``). ``by_attr`` and ``shuffle_strata``
    carry the same types as the identically named keys of ``SplitKwargs``.
    NOTE(review): the consumer is not visible in this module; per-field notes
    below are inferred from names — confirm.
    """

    # Attribute name(s) to balance over: a single name, a list whose entries
    # may be None, or None overall.
    by_attr: str | list[str | None] | None
    # Presumably one lower bound per split on the number of items — confirm.
    split_min_sizes: list[int] | None
    # Presumably one upper bound per split on the number of items — confirm.
    split_max_sizes: list[int] | None
    # Presumably shuffles items within each stratum — confirm.
    shuffle_strata: bool
|
|
|
|
|
|
class OptimizerKwargs(TypedDict, total=False):
    """Optional keyword arguments for a torch optimizer constructor.

    All keys are optional (``total=False``). The key set (``betas``,
    ``amsgrad``, ``foreach``, ``capturable``, ``differentiable``, ``fused``)
    matches the keyword parameters of ``torch.optim.Adam`` / ``torch.optim.AdamW``.
    NOTE(review): confirm which optimizer class these are actually passed to.
    """

    # The learning rate may be given as a Tensor as well as a plain float.
    lr: float | Tensor
    # Pair of coefficients; each element may be a float or a Tensor.
    betas: tuple[float | Tensor, float | Tensor]
    eps: float
    weight_decay: float
    amsgrad: bool
    maximize: bool
    # None leaves the implementation choice to torch (see torch.optim docs).
    foreach: bool | None
    capturable: bool
    differentiable: bool
    # None leaves the implementation choice to torch (see torch.optim docs).
    fused: bool | None
|
|
|
|
|
|
class SubplotsKwargs(TypedDict, total=False):
|
|
sharex: bool | str
|
|
sharey: bool | str
|
|
squeeze: bool
|
|
width_ratios: list[float]
|
|
height_ratios: list[float]
|
|
subplot_kw: dict
|
|
gridspec_kw: dict
|
|
figsize: tuple[float, float]
|
|
dpi: float
|
|
layout: str
|
|
|
|
sharex: bool | Literal["none", "all", "row", "col"] = False,
|
|
sharey: bool | Literal["none", "all", "row", "col"] = False,
|
|
squeeze: bool = True,
|
|
width_ratios: Sequence[float] | None = None,
|
|
height_ratios: Sequence[float] | None = None,
|
|
subplot_kw: dict[str, Any] | None = None,
|
|
gridspec_kw: dict[str, Any] | None = None,
|