import logging
from typing import Unpack, NotRequired
from collections.abc import Callable, Generator

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter

from trainlib.estimator import Estimator, EstimatorKwargs
from trainlib.utils.type import OptimizerKwargs
from trainlib.utils.module import get_grad_norm
from trainlib.estimators.tdnn import TDNNLayer

logger: logging.Logger = logging.getLogger(__name__)


class MLPKwargs(EstimatorKwargs):
    inputs: Tensor
    labels: NotRequired[Tensor]
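    # ``labels`` is optional so forward-only calls, e.g. a hypothetical
    # ``model(inputs=batch)``, can omit it; ``loss`` and ``metrics`` below read it.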


class MLP[K: MLPKwargs](Estimator[K]):
    """
    Base MLP (multi-layer perceptron) architecture.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dims: list[int] | None = None,
        norm_layer: Callable[..., nn.Module] | None = None,
        activation_fn: Callable[..., nn.Module] | None = None,
        inplace: bool = False,
        bias: bool = True,
        dropout: float = 0.0,
        verbose: bool = True,
    ) -> None:
        """
        Parameters:
            input_dim: dimensionality of the input
            output_dim: dimensionality of the output
            hidden_dims: dimensionalities of the hidden layers
            norm_layer: optional normalization factory applied after each hidden linear layer
            activation_fn: activation factory (must accept ``inplace``); defaults to ``nn.ReLU``
            inplace: whether activation and dropout run in place
            bias: whether the hidden linear layers use a bias term
            dropout: dropout probability applied after each activation
            verbose: whether to log the architecture on construction
        """

        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dims = hidden_dims or []
        self.norm_layer = norm_layer
        self.activation_fn = activation_fn or nn.ReLU

        # A plain list is enough here: the modules are registered once they are
        # wrapped in ``nn.Sequential`` below.
        self._layers: list[nn.Module] = []

        layer_in_dim = input_dim
        for layer_out_dim in self.hidden_dims:
            hidden_layer = nn.Linear(layer_in_dim, layer_out_dim, bias=bias)
            self._layers.append(hidden_layer)

            if norm_layer is not None:
                self._layers.append(norm_layer(layer_out_dim))

            self._layers.append(self.activation_fn(inplace=inplace))
            self._layers.append(nn.Dropout(dropout, inplace=inplace))

            layer_in_dim = layer_out_dim

        self._layers.append(nn.Linear(layer_in_dim, self.output_dim))
        self._net = nn.Sequential(*self._layers)
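        # Illustrative only: with hypothetical arguments hidden_dims=[64, 32],
        # norm_layer=None, and the default ReLU activation, ``self._net`` is
        #   Linear(input_dim, 64) -> ReLU -> Dropout(p)
        #   -> Linear(64, 32)     -> ReLU -> Dropout(p)
        #   -> Linear(32, output_dim)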

        if verbose:
            self.log_arch()

    def _clamp_rand(self, x: Tensor) -> Tensor:
        """
        Add zero-centred uniform noise of width 1/127 to ``x`` and clamp the
        result to [-1.0, 1.0].
        """
        return torch.clamp(
            x + (1.0 / 127.0) * (torch.rand_like(x) - 0.5),
            min=-1.0,
            max=1.0,
        )

    def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]:
        inputs = kwargs["inputs"]
        x = self._net(inputs)

        return (x,)

    def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]:
        predictions = self(**kwargs)[0]
        labels = kwargs["labels"]

        yield F.mse_loss(predictions, labels)

    def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]:
        with torch.no_grad():
            loss = next(self.loss(**kwargs)).item()

            predictions = self(**kwargs)[0]
            labels = kwargs["labels"]
            mae = F.l1_loss(predictions, labels).item()

        return {
            "mse": loss,
            "mae": mae,
            "grad_norm": get_grad_norm(self),
        }

    def optimizers(
        self,
        **kwargs: Unpack[OptimizerKwargs],
    ) -> tuple[Optimizer, ...]:
        """
        Build the model's optimizer(s): a single AdamW over all parameters.
        """

        default_kwargs: OptimizerKwargs = {
            "lr": 1e-3,
            "eps": 1e-8,
        }
        opt_kwargs = {**default_kwargs, **kwargs}
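        # Caller-supplied kwargs win on merge: a hypothetical ``optimizers(lr=3e-4)``
        # call would use lr=3e-4 while keeping the eps=1e-8 default.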

        optimizer = torch.optim.AdamW(
            self.parameters(),
            **opt_kwargs,
        )

        return (optimizer,)

    def epoch_step(self) -> None:
        """No-op: the base MLP keeps no per-epoch state."""
        return None

    def epoch_write(
        self,
        writer: SummaryWriter,
        step: int | None = None,
        val: bool = False,
        **kwargs: Unpack[K],
    ) -> None:
        """No-op: the base MLP writes no extra summaries."""
        return None

    def log_arch(self) -> None:
        super().log_arch()

        logger.info(f"| > {self.input_dim=}")
        logger.info(f"| > {self.hidden_dims=}")
        logger.info(f"| > {self.output_dim=}")
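

# Minimal usage sketch (illustrative, not part of the library): the shapes and
# hyperparameters below are hypothetical, and it assumes ``Estimator`` needs no
# extra constructor arguments beyond those passed to ``MLP``.
if __name__ == "__main__":
    model = MLP(input_dim=16, output_dim=1, hidden_dims=[64, 32], dropout=0.1)
    (optimizer,) = model.optimizers(lr=1e-3)

    inputs = torch.randn(8, 16)  # batch of 8 feature vectors
    labels = torch.randn(8, 1)   # matching regression targets

    optimizer.zero_grad()
    (loss,) = tuple(model.loss(inputs=inputs, labels=labels))
    loss.backward()
    optimizer.step()

    print(model.metrics(inputs=inputs, labels=labels))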