import logging
from typing import Unpack, NotRequired
from collections.abc import Generator

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter

from trainlib.estimator import Estimator, EstimatorKwargs
from trainlib.utils.type import OptimizerKwargs
from trainlib.utils.module import get_grad_norm
from trainlib.estimators.tdnn import TDNNLayer

logger: logging.Logger = logging.getLogger(__name__)


class RNNKwargs(EstimatorKwargs):
    inputs: Tensor
    labels: NotRequired[Tensor]


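# Illustrative sketch (not part of the original module): build a batch dict that
# follows the RNNKwargs layout used throughout this file. The (B, C, T) input
# convention comes from the forward() comments below; the concrete sizes and the
# helper name are placeholders, and EstimatorKwargs is assumed to add no further
# required keys.
def _example_rnn_batch(
    batch_size: int = 8,
    channels: int = 12,
    steps: int = 50,
    label_dim: int = 3,
) -> RNNKwargs:
    """Assemble a random batch matching the RNNKwargs TypedDict."""
    return {
        "inputs": torch.randn(batch_size, channels, steps),  # (B, C, T)
        "labels": torch.randn(batch_size, label_dim),
    }

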
class LSTM[Kw: RNNKwargs](Estimator[Kw]):
    """
    Base RNN architecture.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: int = 64,
        num_layers: int = 4,
        bidirectional: bool = False,
        verbose: bool = True,
    ) -> None:
        """
        Parameters:
            input_dim: dimensionality of the input
            output_dim: dimensionality of the output
            hidden_dim: dimensionality of each LSTM layer output
            num_layers: number of LSTM layers to use
            bidirectional: whether to run the LSTM in both directions
            verbose: whether to log the architecture on construction
        """

        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        self.dense_in = nn.Linear(input_dim, hidden_dim)
        self.lstm = nn.LSTM(
            hidden_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
        )

        lstm_out_dim = hidden_dim * (2 if bidirectional else 1)
        self.dense_z = nn.Linear(lstm_out_dim, output_dim)

        # weight initialization for LSTM layers
        def init_weights(m: nn.Module) -> None:
            if isinstance(m, nn.LSTM):
                for name, p in m.named_parameters():
                    if "weight_ih" in name:
                        nn.init.xavier_uniform_(p)
                    elif "weight_hh" in name:
                        nn.init.orthogonal_(p)
                    elif "bias" in name:
                        nn.init.zeros_(p)

        self.apply(init_weights)

        if verbose:
            self.log_arch()

    def _clamp_rand(self, x: Tensor) -> Tensor:
        # add small uniform noise (roughly +/- 1/254) and clamp to [-1, 1]
        return torch.clamp(
            x + (1.0 / 127.0) * (torch.rand_like(x) - 0.5),
            min=-1.0,
            max=1.0,
        )

    def forward(self, **kwargs: Unpack[Kw]) -> tuple[Tensor, ...]:
        inputs = kwargs["inputs"]

        # data shaped (B, C, T); map to (B, T, C)
        x = inputs.permute(0, 2, 1)
        x = torch.tanh(self.dense_in(x))
        x = self._clamp_rand(x)
        x, hidden = self.lstm(x)
        z = self.dense_z(x)

        return z[:, -1, :], hidden

    def loss(self, **kwargs: Unpack[Kw]) -> Generator[Tensor]:
        predictions = self(**kwargs)[0]
        labels = kwargs["labels"]

        yield F.mse_loss(predictions, labels)

    def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
        with torch.no_grad():
            loss = next(self.loss(**kwargs)).item()

            predictions = self(**kwargs)[0]
            labels = kwargs["labels"]
            mae = F.l1_loss(predictions, labels).item()

        return {
            # "loss": loss,
            "mse": loss,
            "mae": mae,
            "grad_norm": get_grad_norm(self)
        }

    def optimizers(
        self,
        **kwargs: Unpack[OptimizerKwargs],
    ) -> tuple[Optimizer, ...]:
        """
        Build the AdamW optimizer, merging caller overrides with the defaults.
        """

        default_kwargs: OptimizerKwargs = {
            "lr": 1e-3,
            "eps": 1e-8,
        }
        opt_kwargs = {**default_kwargs, **kwargs}

        optimizer = torch.optim.AdamW(
            self.parameters(),
            **opt_kwargs,
        )

        return (optimizer,)

    def epoch_step(self) -> None:
        """No per-epoch state to update for this estimator."""
        return None

    def epoch_write(
        self,
        writer: SummaryWriter,
        step: int | None = None,
        group: str | None = None,
        **kwargs: Unpack[Kw],
    ) -> None:
        """No per-epoch summaries are written for this estimator."""
        return None

    def log_arch(self) -> None:
        super().log_arch()

        logger.info(f"| > {self.input_dim=}")
        logger.info(f"| > {self.hidden_dim=}")
        logger.info(f"| > {self.num_layers=}")
        logger.info(f"| > {self.output_dim=}")


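# Illustrative usage sketch (not part of the original module). It assumes that
# Estimator is an nn.Module subclass (as the use of apply()/parameters() above
# suggests) and reuses the hypothetical _example_rnn_batch helper defined earlier.
# Shapes and hyperparameters below are arbitrary placeholders.
def _example_lstm_step() -> None:
    """Run one hypothetical training step with the LSTM estimator."""
    model = LSTM(input_dim=12, output_dim=3, hidden_dim=32, num_layers=2, verbose=False)
    (optimizer,) = model.optimizers(lr=1e-3)

    batch = _example_rnn_batch(batch_size=8, channels=12, steps=50, label_dim=3)

    optimizer.zero_grad()
    loss = next(model.loss(**batch))
    loss.backward()
    optimizer.step()

    logger.info("example LSTM step: mse=%.4f", loss.item())

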
class MultiheadLSTMKwargs(EstimatorKwargs):
    inputs: Tensor
    labels: NotRequired[Tensor]
    auxiliary: NotRequired[Tensor]


class MultiheadLSTM[Kw: MultiheadLSTMKwargs](Estimator[Kw]):
    """
    LSTM estimator with a main output head plus auxiliary prediction heads.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: int = 64,
        num_layers: int = 4,
        bidirectional: bool = False,
        head_dims: list[int] | None = None,
        verbose: bool = True,
    ) -> None:
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.head_dims = head_dims if head_dims is not None else []

        self.dense_in = nn.Linear(input_dim, hidden_dim)
        self.lstm = nn.LSTM(
            hidden_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
        )

        lstm_out_dim = hidden_dim * (2 if bidirectional else 1)
        self.dense_z_out = nn.Linear(lstm_out_dim, output_dim)
        self.dense_z_heads = nn.ModuleList([
            nn.Linear(lstm_out_dim, head_dim)
            for head_dim in self.head_dims
        ])

        # weight initialization for LSTM layers
        def init_weights(m: nn.Module) -> None:
            if isinstance(m, nn.LSTM):
                for name, p in m.named_parameters():
                    if "weight_ih" in name:
                        nn.init.xavier_uniform_(p)
                    elif "weight_hh" in name:
                        nn.init.orthogonal_(p)
                    elif "bias" in name:
                        nn.init.zeros_(p)

        self.apply(init_weights)

        if verbose:
            self.log_arch()

    def _clamp_rand(self, x: Tensor) -> Tensor:
        # add small uniform noise (roughly +/- 1/254) and clamp to [-1, 1]
        return torch.clamp(
            x + (1.0 / 127.0) * (torch.rand_like(x) - 0.5),
            min=-1.0,
            max=1.0,
        )

    def forward(self, **kwargs: Unpack[Kw]) -> tuple[Tensor, ...]:
        inputs = kwargs["inputs"]

        # data shaped (B, C, T); map to (B, T, C)
        x = inputs.permute(0, 2, 1)
        x = torch.tanh(self.dense_in(x))
        x = self._clamp_rand(x)
        x, hidden = self.lstm(x)

        z = self.dense_z_out(x)
        # concatenate all auxiliary head outputs along the channel dimension
        zs = torch.cat([head(x) for head in self.dense_z_heads], dim=-1)

        return z[:, -1, :], zs[:, -1, :]

    def loss(self, **kwargs: Unpack[Kw]) -> Generator[Tensor]:
        pred, pred_aux = self(**kwargs)
        labels = kwargs["labels"]
        aux_labels = kwargs.get("auxiliary")

        if aux_labels is None:
            yield F.mse_loss(pred, labels)
        else:
            yield F.mse_loss(pred, labels) + F.mse_loss(pred_aux, aux_labels)

    def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
        return {
            "grad_norm": get_grad_norm(self)
        }

    def optimizers(
        self,
        **kwargs: Unpack[OptimizerKwargs],
    ) -> tuple[Optimizer, ...]:
        """
        Build the AdamW optimizer, merging caller overrides with the defaults.
        """

        default_kwargs: OptimizerKwargs = {
            "lr": 1e-3,
            "eps": 1e-8,
        }
        opt_kwargs = {**default_kwargs, **kwargs}

        optimizer = torch.optim.AdamW(
            self.parameters(),
            **opt_kwargs,
        )

        return (optimizer,)

    def epoch_step(self) -> None:
        """No per-epoch state to update for this estimator."""
        return None

    def epoch_write(
        self,
        writer: SummaryWriter,
        step: int | None = None,
        group: str | None = None,
        **kwargs: Unpack[Kw],
    ) -> None:
        """No per-epoch summaries are written for this estimator."""
        return None

    def log_arch(self) -> None:
        super().log_arch()

        logger.info(f"| > {self.input_dim=}")
        logger.info(f"| > {self.hidden_dim=}")
        logger.info(f"| > {self.num_layers=}")
        logger.info(f"| > {self.output_dim=}")


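# Illustrative usage sketch (not part of the original module). MultiheadLSTM
# returns the main prediction plus the concatenated auxiliary-head outputs, and
# forward() concatenates over dense_z_heads, so at least one entry in head_dims
# appears to be required. Shapes and sizes below are placeholders.
def _example_multihead_lstm_loss() -> None:
    """Compute the combined main + auxiliary MSE loss once."""
    model = MultiheadLSTM(
        input_dim=12,
        output_dim=3,
        hidden_dim=32,
        num_layers=2,
        head_dims=[2, 4],
        verbose=False,
    )

    batch = {
        "inputs": torch.randn(8, 12, 50),  # (B, C, T)
        "labels": torch.randn(8, 3),       # main head target
        "auxiliary": torch.randn(8, 6),    # sum(head_dims) == 6
    }

    loss = next(model.loss(**batch))
    logger.info("example MultiheadLSTM loss: %.4f", loss.item())

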
class ConvGRU[Kw: RNNKwargs](Estimator[Kw]):
    """
    Base recurrent convolutional architecture.

    Stacks GRU and dilated TDNN (convolutional) layers with dense skip
    connections to map an input feature sequence to an output estimate.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        temporal_dim: int,
        gru_dim: int = 64,
        conv_dim: int = 96,
        num_layers: int = 4,
        conv_kernel_sizes: list[int] | None = None,
        conv_dilations: list[int] | None = None,
        verbose: bool = True,
    ) -> None:
        """
        Parameters:
            input_dim: dimensionality of the input
            output_dim: dimensionality of the output
            temporal_dim: kernel size of the final output layer over the time axis
            gru_dim: dimensionality of each GRU layer output
            conv_dim: dimensionality of each conv layer output
            num_layers: number of gru-conv layer pairs to use
            conv_kernel_sizes: kernel sizes for conv layers
            conv_dilations: dilation settings for conv layers
            verbose: whether to log the architecture on construction
        """

        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        self.gru_dim = gru_dim
        self.conv_dim = conv_dim
        self.num_layers = num_layers
        self.receptive_field = 0

        self.conv_kernel_sizes: list[int]
        if conv_kernel_sizes is None:
            self.conv_kernel_sizes = [4] * num_layers
        else:
            self.conv_kernel_sizes = conv_kernel_sizes

        self.conv_dilations: list[int]
        if conv_dilations is None:
            self.conv_dilations = [1] + [2] * (num_layers - 1)
        else:
            self.conv_dilations = conv_dilations

        self._gru_layers: nn.ModuleList = nn.ModuleList()
        self._conv_layers: nn.ModuleList = nn.ModuleList()

        # each GRU/conv layer sees the concatenation of all previous layer
        # outputs (dense skip connections), so the input width grows per layer
        layer_in_dim = gru_dim
        for i in range(self.num_layers):
            gru_layer = nn.GRU(layer_in_dim, gru_dim, batch_first=True)
            self._gru_layers.append(gru_layer)
            layer_in_dim += gru_dim

            tdnn_layer = TDNNLayer(
                layer_in_dim,
                conv_dim,
                kernel_size=self.conv_kernel_sizes[i],
                dilation=self.conv_dilations[i],
                # pad=False,
            )
            self.receptive_field += tdnn_layer.receptive_field

            self._conv_layers.append(tdnn_layer)
            layer_in_dim += conv_dim

        # self.dense_in = nn.Linear(self.input_dim, gru_dim)
        self.dense_in = TDNNLayer(
            self.input_dim,
            gru_dim,
            kernel_size=1,
            pad=False,
        )

        # will be (B, T, C), applies indep at each time step across channels
        # self.dense_z = nn.Linear(layer_in_dim, self.output_dim)

        # will be (B, C, T), applies indep at each time step across channels
        self.dense_z = TDNNLayer(
            layer_in_dim,
            self.output_dim,
            kernel_size=temporal_dim,
            pad=False,
        )

        # weight initialization for GRU layers
        def init_weights(module: nn.Module) -> None:
            if isinstance(module, nn.GRU):
                for name, param in module.named_parameters():
                    if name.startswith("weight_hh_"):
                        nn.init.orthogonal_(param)

        self.apply(init_weights)

        if verbose:
            self.log_arch()

    def _clamp_rand(self, x: Tensor) -> Tensor:
        # add small uniform noise (roughly +/- 1/254) and clamp to [-1, 1]
        return torch.clamp(
            x + (1.0 / 127.0) * (torch.rand_like(x) - 0.5),
            min=-1.0,
            max=1.0,
        )

    def forward(self, **kwargs: Unpack[Kw]) -> tuple[Tensor, ...]:
        inputs = kwargs["inputs"]

        # embedding shaped (B, C, T)
        x = self._clamp_rand(torch.tanh(self.dense_in(inputs)))

        # prepare shape (B, T, C) -- for GRU
        x = x.transpose(-2, -1)

        for gru, conv in zip(self._gru_layers, self._conv_layers, strict=True):
            xg = self._clamp_rand(gru(x)[0])
            x = torch.cat([x, xg], -1)

            xc = self._clamp_rand(conv(x.transpose(-2, -1)))
            xc = xc.transpose(-2, -1)
            x = torch.cat([x, xc], -1)

        # z = self.dense_z(x)
        # z = z.transpose(-2, -1)

        # map to (B, C, T)
        x = x.transpose(-2, -1)
        z = self.dense_z(x)

        return (z,)

    def loss(self, **kwargs: Unpack[Kw]) -> Generator[Tensor]:
        predictions = self(**kwargs)[0]
        labels = kwargs["labels"]

        # squeeze last dim; we've mapped T -> 1
        predictions = predictions.squeeze(-1)

        yield F.mse_loss(predictions, labels, reduction="mean")

    def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
        with torch.no_grad():
            loss = next(self.loss(**kwargs)).item()

            predictions = self(**kwargs)[0].squeeze(-1)
            labels = kwargs["labels"]
            mae = F.l1_loss(predictions, labels).item()

        return {
            "mse": loss,
            "mae": mae,
            "grad_norm": get_grad_norm(self)
        }

    def optimizers(
        self,
        **kwargs: Unpack[OptimizerKwargs],
    ) -> tuple[Optimizer, ...]:
        """
        Build the AdamW optimizer, merging caller overrides with the defaults.
        """

        default_kwargs: OptimizerKwargs = {
            "lr": 1e-3,
            "eps": 1e-8,
        }
        opt_kwargs = {**default_kwargs, **kwargs}

        optimizer = torch.optim.AdamW(
            self.parameters(),
            **opt_kwargs,
        )

        return (optimizer,)

    def epoch_step(self) -> None:
        """No per-epoch state to update for this estimator."""
        return None

    def epoch_write(
        self,
        writer: SummaryWriter,
        step: int | None = None,
        group: str | None = None,
        **kwargs: Unpack[Kw],
    ) -> None:
        """No per-epoch summaries are written for this estimator."""
        return None

    def log_arch(self) -> None:
        super().log_arch()

        logger.info(f"| > {self.input_dim=}")
        logger.info(f"| > {self.gru_dim=}")
        logger.info(f"| > {self.conv_dim=}")
        logger.info(f"| > {self.num_layers=}")
        logger.info(f"| > {self.conv_kernel_sizes=}")
        logger.info(f"| > {self.conv_dilations=}")
        logger.info(f"| > {self.receptive_field=}")
        logger.info(f"| > {self.output_dim=}")


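# Illustrative usage sketch (not part of the original module). It assumes the
# stacked TDNNLayer blocks keep the sequence length unchanged (pad is left at
# its default above), so the final dense_z layer with kernel_size=temporal_dim
# and pad=False collapses T down to a single step, matching the squeeze(-1) in
# ConvGRU.loss(). Sizes below are placeholders.
def _example_conv_gru_forward() -> None:
    """Run one ConvGRU forward/loss pass on a random batch."""
    temporal_dim = 50
    model = ConvGRU(
        input_dim=12,
        output_dim=3,
        temporal_dim=temporal_dim,
        gru_dim=16,
        conv_dim=24,
        num_layers=2,
        verbose=False,
    )

    batch = {
        "inputs": torch.randn(8, 12, temporal_dim),  # (B, C, T) with T == temporal_dim
        "labels": torch.randn(8, 3),
    }

    (z,) = model(**batch)
    logger.info("example ConvGRU output shape: %s", tuple(z.shape))  # expected (8, 3, 1)

    loss = next(model.loss(**batch))
    logger.info("example ConvGRU loss: %.4f", loss.item())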