From 95d7bc68ce6ca7f660f13b951c0bb54bfdaca0bf Mon Sep 17 00:00:00 2001 From: smgr Date: Tue, 10 Mar 2026 02:39:27 -0700 Subject: [PATCH] add MLP estimator, update Estimator generics --- trainlib/datasets/memory.py | 8 +- trainlib/domain.py | 11 +++ trainlib/domains/functional.py | 14 +++- trainlib/estimators/mlp.py | 149 +++++++++++++++++++++++++++++++++ trainlib/estimators/rnn.py | 1 - trainlib/trainer.py | 24 ++++-- 6 files changed, 189 insertions(+), 18 deletions(-) create mode 100644 trainlib/estimators/mlp.py diff --git a/trainlib/datasets/memory.py b/trainlib/datasets/memory.py index c1e9adf..42059f0 100644 --- a/trainlib/datasets/memory.py +++ b/trainlib/datasets/memory.py @@ -80,10 +80,10 @@ from trainlib.domain import SequenceDomain from trainlib.dataset import TupleDataset, DatasetKwargs -class SlidingWindowDataset[T: Tensor](TupleDataset[T]): +class SlidingWindowDataset(TupleDataset[Tensor]): def __init__( self, - domain: SequenceDomain[tuple[T, ...]], + domain: SequenceDomain[tuple[Tensor, ...]], lookback: int, offset: int = 0, lookahead: int = 1, @@ -99,9 +99,9 @@ class SlidingWindowDataset[T: Tensor](TupleDataset[T]): def _process_batch_data( self, - batch_data: tuple[T, ...], + batch_data: tuple[Tensor, ...], batch_index: int, - ) -> list[tuple[T, ...]]: + ) -> list[tuple[Tensor, ...]]: """ Backward pads first sequence over (lookback-1) length, and steps the remaining items forward by the lookahead. diff --git a/trainlib/domain.py b/trainlib/domain.py index 00d77bf..84ad5bb 100644 --- a/trainlib/domain.py +++ b/trainlib/domain.py @@ -64,3 +64,14 @@ class SequenceDomain[R](Domain[int, R]): def __len__(self) -> int: return len(self.sequence) + + +class TupleDomain[T](SequenceDomain[tuple[T, ...]]): + """ + Domain for homogeneous tuples of the same type. + + This class header exists primarily as a typed alias that aligns with + TupleDataset. + """ + + ... 
diff --git a/trainlib/domains/functional.py b/trainlib/domains/functional.py index 5ce5ba3..7c94860 100644 --- a/trainlib/domains/functional.py +++ b/trainlib/domains/functional.py @@ -8,10 +8,16 @@ class SimulatorDomain[P, R](Domain[int, R]): Base simulator domain, generic to arbitrary callables. Note: we don't store simulation results here; that's left to a downstream - object, like a `BatchedDataset`, to cache if needed. We also don't subclass - `SequenceDataset` because the item getter type doesn't align: we accept an - `int` in the parameter list, but don't return the items directly from that - collection (we transform them first). + object, like a ``BatchedDataset``, to cache if needed. We also don't + subclass ``SequenceDataset`` because the item getter type doesn't align: we + accept an ``int`` in the parameter list, but don't return the items + directly from that collection (we transform them first). + + Note: it's interesting to consider the idea of having parameters directly + act as URIs. There is, however, no obvious way to iterate over allowed + parameters (without additional components, like a prior or some other + generator), so we leave that outside the class scope and simply operate + over a provided parameter sequence. 
""" def __init__( diff --git a/trainlib/estimators/mlp.py b/trainlib/estimators/mlp.py new file mode 100644 index 0000000..27edf63 --- /dev/null +++ b/trainlib/estimators/mlp.py @@ -0,0 +1,149 @@ +import logging +from typing import Unpack, NotRequired +from collections.abc import Callable, Generator + +import torch +import torch.nn.functional as F +from torch import nn, Tensor +from torch.optim import Optimizer +from torch.utils.tensorboard import SummaryWriter + +from trainlib.estimator import Estimator, EstimatorKwargs +from trainlib.utils.type import OptimizerKwargs +from trainlib.utils.module import get_grad_norm +from trainlib.estimators.tdnn import TDNNLayer + +logger: logging.Logger = logging.getLogger(__name__) + + +class MLPKwargs(EstimatorKwargs): + inputs: Tensor + labels: NotRequired[Tensor] + + +class MLP[K: MLPKwargs](Estimator[K]): + """ + Base MLP architecture. + """ + + def __init__( + self, + input_dim: int, + output_dim: int, + hidden_dims: list[int] | None = None, + norm_layer: Callable[..., nn.Module] | None = None, + activation_fn: nn.Module | None = None, + inplace: bool = False, + bias: bool = True, + dropout: float = 0.0, + verbose: bool = True, + ) -> None: + """ + Parameters: + input_dim: dimensionality of the input + output_dim: dimensionality of the output + hidden_dims: dimensionalities of hidden layers + """ + + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.hidden_dims = hidden_dims or [] + self.norm_layer = norm_layer + self.activation_fn = activation_fn or nn.ReLU + + # self._layers: nn.ModuleList = nn.ModuleList() + self._layers = [] + + layer_in_dim = input_dim + for layer_out_dim in self.hidden_dims: + hidden_layer = nn.Linear(layer_in_dim, layer_out_dim, bias=bias) + self._layers.append(hidden_layer) + + if norm_layer is not None: + self._layers.append(norm_layer(layer_out_dim)) + + self._layers.append(self.activation_fn(inplace=inplace)) + self._layers.append(nn.Dropout(dropout, 
inplace=inplace)) + + layer_in_dim = layer_out_dim + + self._layers.append(nn.Linear(layer_in_dim, self.output_dim)) + self._net = nn.Sequential(*self._layers) + + if verbose: + self.log_arch() + + def _clamp_rand(self, x: Tensor) -> Tensor: + return torch.clamp( + x + (1.0 / 127.0) * (torch.rand_like(x) - 0.5), + min=-1.0, + max=1.0, + ) + + def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]: + inputs = kwargs["inputs"] + x = self._net(inputs) + + return (x,) + + def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]: + predictions = self(**kwargs)[0] + labels = kwargs["labels"] + + yield F.mse_loss(predictions, labels) + + def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]: + with torch.no_grad(): + loss = next(self.loss(**kwargs)).item() + + predictions = self(**kwargs)[0] + labels = kwargs["labels"] + mae = F.l1_loss(predictions, labels).item() + + return { + "loss": loss, + "mse": loss, + "mae": mae, + "grad_norm": get_grad_norm(self) + } + + def optimizers( + self, + **kwargs: Unpack[OptimizerKwargs], + ) -> tuple[Optimizer, ...]: + """ + """ + + default_kwargs: Unpack[OptimizerKwargs] = { + "lr": 1e-3, + "eps": 1e-8, + } + opt_kwargs = {**default_kwargs, **kwargs} + + optimizer = torch.optim.AdamW( + self.parameters(), + **opt_kwargs, + ) + + return (optimizer,) + + def epoch_step(self) -> None: + return None + + def epoch_write( + self, + writer: SummaryWriter, + step: int | None = None, + val: bool = False, + **kwargs: Unpack[K], + ) -> None: + return None + + def log_arch(self) -> None: + super().log_arch() + + logger.info(f"| > {self.input_dim=}") + logger.info(f"| > {self.hidden_dims=}") + logger.info(f"| > {self.output_dim=}") diff --git a/trainlib/estimators/rnn.py b/trainlib/estimators/rnn.py index 8ff620b..b471753 100644 --- a/trainlib/estimators/rnn.py +++ b/trainlib/estimators/rnn.py @@ -458,7 +458,6 @@ class ConvGRU[K: RNNKwargs](Estimator[K]): "grad_norm": get_grad_norm(self) } - def optimizers( self, **kwargs: 
Unpack[OptimizerKwargs], diff --git a/trainlib/trainer.py b/trainlib/trainer.py index 649097e..d24cc84 100644 --- a/trainlib/trainer.py +++ b/trainlib/trainer.py @@ -258,10 +258,10 @@ class Trainer[I, K: EstimatorKwargs]: return val_loss_sums - def train( + def train[B]( self, dataset: BatchedDataset[..., ..., I], - batch_estimator_map: Callable[[I, Self], K], + batch_estimator_map: Callable[[B, Self], K], lr: float = 1e-3, eps: float = 1e-8, max_grad_norm: float | None = None, @@ -280,6 +280,10 @@ class Trainer[I, K: EstimatorKwargs]: summary_writer: SummaryWriter | None = None, ) -> Estimator: """ + TODO: consider making the dataloader ``collate_fn`` an explicit + parameter with a type signature that reflects ``B``, connecting the + ``batch_estimator_map`` somewhere + Note: this method attempts to implement a general scheme for passing needed items to the estimator's loss function from the dataloader. The abstract ``Estimator`` base only requires the model output be provided @@ -289,7 +293,7 @@ class Trainer[I, K: EstimatorKwargs]: further logic to the ``loss`` method of the underlying estimator, so one should take care to synchronize the sample structure with `dataset` to match that expected by ``self.estimator.loss(...)``. - + .. admonition:: On ``batch_estimator_map`` Dataloader collate functions are responsible for mapping a @@ -306,7 +310,7 @@ class Trainer[I, K: EstimatorKwargs]: the collate function maps back into the item skeleton, producing a single tuple of (stacked) tensors - + .. code-block:: text ( [[1, 1], @@ -317,13 +321,15 @@ class Trainer[I, K: EstimatorKwargs]: [2, 2], [3, 3]] ) - This function should map from batches (which should be *item - shaped*, i.e., have an ``I`` skeleton, even if stacked items may be - different on the inside) into estimator keyword arguments (type - ``K``). 
+ This function should map from batches - which *may* be item + shaped, i.e., have an ``I`` skeleton, even if stacked items may be + different on the inside - into estimator keyword arguments (type + ``K``). Collation behavior from a DataLoader (which can be + customized) doesn't consistently yield a known type shape, however, + so it's not appropriate to use ``I`` as the callable param type. Parameters: - dataset: dataset to train the estimator + dataset: dataset to train the estimator batch_estimator_map: function mapping from batch data to expected estimator kwargs lr: learning rate (default: 1e-3)