# trainlib/trainlib/estimators/rnn.py
import logging
from typing import Unpack, NotRequired
from collections.abc import Generator
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from trainlib.estimator import Estimator, EstimatorKwargs
from trainlib.utils.type import OptimizerKwargs
from trainlib.utils.module import get_grad_norm
from trainlib.estimators.tdnn import TDNNLayer
logger: logging.Logger = logging.getLogger(__name__)
class RNNKwargs(EstimatorKwargs):
inputs: Tensor
labels: NotRequired[Tensor]
class LSTM[Kw: RNNKwargs](Estimator[Kw]):
"""
Base RNN architecture.
"""
def __init__(
self,
input_dim: int,
output_dim: int,
hidden_dim: int = 64,
num_layers: int = 4,
bidirectional: bool = False,
verbose: bool = True,
) -> None:
"""
Parameters:
input_dim: dimensionality of the input
output_dim: dimensionality of the output
rnn_dim: dimensionality of each RNN layer output
num_layers: number of LSTM layers pairs to use
"""
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.dense_in = nn.Linear(input_dim, hidden_dim)
self.lstm = nn.LSTM(
hidden_dim,
hidden_dim,
num_layers=num_layers,
batch_first=True,
bidirectional=bidirectional,
)
lstm_out_dim = hidden_dim * (2 if bidirectional else 1)
self.dense_z = nn.Linear(lstm_out_dim, output_dim)
# weight initialization for LSTM layers
def init_weights(m: nn.Module) -> None:
if isinstance(m, nn.LSTM):
for name, p in m.named_parameters():
if "weight_ih" in name:
nn.init.xavier_uniform_(p)
elif "weight_hh" in name:
nn.init.orthogonal_(p)
elif "bias" in name:
nn.init.zeros_(p)
self.apply(init_weights)
if verbose:
self.log_arch()
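    # _clamp_rand adds uniform noise of ±0.5/127 (on the order of one int8
    # quantization step over [-1, 1]) before clamping, presumably as light
    # activation noise / regularization.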
def _clamp_rand(self, x: Tensor) -> Tensor:
return torch.clamp(
x + (1.0 / 127.0) * (torch.rand_like(x) - 0.5),
min=-1.0,
max=1.0,
)
def forward(self, **kwargs: Unpack[Kw]) -> tuple[Tensor, ...]:
inputs = kwargs["inputs"]
# data shaped (B, C, T); map to (B, T, C)
x = inputs.permute(0, 2, 1)
x = torch.tanh(self.dense_in(x))
x = self._clamp_rand(x)
x, hidden = self.lstm(x)
z = self.dense_z(x)
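        # read out the prediction at the final time step; `hidden` is the LSTM's (h_n, c_n)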
return z[:, -1, :], hidden
def loss(self, **kwargs: Unpack[Kw]) -> Generator[Tensor]:
predictions = self(**kwargs)[0]
labels = kwargs["labels"]
yield F.mse_loss(predictions, labels)
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
with torch.no_grad():
loss = next(self.loss(**kwargs)).item()
predictions = self(**kwargs)[0]
labels = kwargs["labels"]
mae = F.l1_loss(predictions, labels).item()
        return {
            "mse": loss,
            "mae": mae,
            "grad_norm": get_grad_norm(self),
        }
def optimizers(
self,
**kwargs: Unpack[OptimizerKwargs],
) -> tuple[Optimizer, ...]:
"""
"""
default_kwargs: Unpack[OptimizerKwargs] = {
"lr": 1e-3,
"eps": 1e-8,
}
opt_kwargs = {**default_kwargs, **kwargs}
optimizer = torch.optim.AdamW(
self.parameters(),
**opt_kwargs,
)
return (optimizer,)
def epoch_step(self) -> None:
return None
def epoch_write(
self,
writer: SummaryWriter,
step: int | None = None,
group: str | None = None,
**kwargs: Unpack[Kw],
) -> None:
return None
def log_arch(self) -> None:
super().log_arch()
logger.info(f"| > {self.input_dim=}")
logger.info(f"| > {self.hidden_dim=}")
logger.info(f"| > {self.num_layers=}")
logger.info(f"| > {self.output_dim=}")
class MultiheadLSTMKwargs(EstimatorKwargs):
inputs: Tensor
labels: NotRequired[Tensor]
auxiliary: NotRequired[Tensor]
class MultiheadLSTM[Kw: MultiheadLSTMKwargs](Estimator[Kw]):
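    """
    LSTM estimator with a primary output head (output_dim) plus optional auxiliary
    heads (head_dims). Auxiliary head outputs are concatenated along the feature
    dimension; when `auxiliary` labels are supplied, loss() adds an MSE term for them.
    """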
def __init__(
self,
input_dim: int,
output_dim: int,
hidden_dim: int = 64,
num_layers: int = 4,
bidirectional: bool = False,
head_dims: list[int] | None = None,
verbose: bool = True,
) -> None:
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.head_dims = head_dims if head_dims is not None else []
self.dense_in = nn.Linear(input_dim, hidden_dim)
self.lstm = nn.LSTM(
hidden_dim,
hidden_dim,
num_layers=num_layers,
batch_first=True,
bidirectional=bidirectional,
)
lstm_out_dim = hidden_dim * (2 if bidirectional else 1)
self.dense_z_out = nn.Linear(lstm_out_dim, output_dim)
self.dense_z_heads = nn.ModuleList([
nn.Linear(lstm_out_dim, head_dim)
for head_dim in self.head_dims
])
# weight initialization for LSTM layers
def init_weights(m: nn.Module) -> None:
if isinstance(m, nn.LSTM):
for name, p in m.named_parameters():
if "weight_ih" in name:
nn.init.xavier_uniform_(p)
elif "weight_hh" in name:
nn.init.orthogonal_(p)
elif "bias" in name:
nn.init.zeros_(p)
self.apply(init_weights)
if verbose:
self.log_arch()
def _clamp_rand(self, x: Tensor) -> Tensor:
return torch.clamp(
x + (1.0 / 127.0) * (torch.rand_like(x) - 0.5),
min=-1.0,
max=1.0,
)
def forward(self, **kwargs: Unpack[Kw]) -> tuple[Tensor, ...]:
inputs = kwargs["inputs"]
# data shaped (B, C, T); map to (B, T, C)
x = inputs.permute(0, 2, 1)
x = torch.tanh(self.dense_in(x))
x = self._clamp_rand(x)
x, hidden = self.lstm(x)
        z = self.dense_z_out(x)
        # guard: with head_dims=[] (the default) torch.cat would fail on an empty list
        head_outs = [head(x) for head in self.dense_z_heads]
        zs = torch.cat(head_outs, dim=-1) if head_outs else x.new_zeros(x.size(0), x.size(1), 0)
        return z[:, -1, :], zs[:, -1, :]
def loss(self, **kwargs: Unpack[Kw]) -> Generator[Tensor]:
pred, pred_aux = self(**kwargs)
labels = kwargs["labels"]
aux_labels = kwargs.get("auxiliary")
if aux_labels is None:
yield F.mse_loss(pred, labels)
else:
yield F.mse_loss(pred, labels) + F.mse_loss(pred_aux, aux_labels)
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
return {
"grad_norm": get_grad_norm(self)
}
def optimizers(
self,
**kwargs: Unpack[OptimizerKwargs],
) -> tuple[Optimizer, ...]:
"""
"""
default_kwargs: Unpack[OptimizerKwargs] = {
"lr": 1e-3,
"eps": 1e-8,
}
opt_kwargs = {**default_kwargs, **kwargs}
optimizer = torch.optim.AdamW(
self.parameters(),
**opt_kwargs,
)
return (optimizer,)
def epoch_step(self) -> None:
return None
def epoch_write(
self,
writer: SummaryWriter,
step: int | None = None,
group: str | None = None,
**kwargs: Unpack[Kw],
) -> None:
return None
def log_arch(self) -> None:
super().log_arch()
logger.info(f"| > {self.input_dim=}")
logger.info(f"| > {self.hidden_dim=}")
logger.info(f"| > {self.num_layers=}")
logger.info(f"| > {self.output_dim=}")
class ConvGRU[Kw: RNNKwargs](Estimator[Kw]):
"""
Base recurrent convolutional architecture.
Computes latents, initial states, and rate estimates from features and
lambda parameter.
"""
def __init__(
self,
input_dim: int,
output_dim: int,
temporal_dim: int,
gru_dim: int = 64,
conv_dim: int = 96,
num_layers: int = 4,
conv_kernel_sizes: list[int] | None = None,
conv_dilations: list[int] | None = None,
verbose: bool = True,
) -> None:
"""
Parameters:
input_dim: dimensionality of the input
output_dim: dimensionality of the output
gru_dim: dimensionality of each GRU layer output
conv_dim: dimensionality of each conv layer output
num_layers: number of gru-conv layer pairs to use
conv_kernel_sizes: kernel sizes for conv layers
conv_dilations: dilation settings for conv layers
"""
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.gru_dim = gru_dim
self.conv_dim = conv_dim
self.num_layers = num_layers
self.receptive_field = 0
self.conv_kernel_sizes: list[int]
if conv_kernel_sizes is None:
self.conv_kernel_sizes = [4] * num_layers
else:
self.conv_kernel_sizes = conv_kernel_sizes
self.conv_dilations: list[int]
if conv_dilations is None:
self.conv_dilations = [1] + [2] * (num_layers - 1)
else:
self.conv_dilations = conv_dilations
self._gru_layers: nn.ModuleList = nn.ModuleList()
self._conv_layers: nn.ModuleList = nn.ModuleList()
layer_in_dim = gru_dim
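        # densely connected stack: each GRU/conv layer consumes the concatenation of
        # all earlier outputs, so layer_in_dim grows by gru_dim + conv_dim per pair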
for i in range(self.num_layers):
gru_layer = nn.GRU(layer_in_dim, gru_dim, batch_first=True)
self._gru_layers.append(gru_layer)
layer_in_dim += gru_dim
tdnn_layer = TDNNLayer(
layer_in_dim,
conv_dim,
kernel_size=self.conv_kernel_sizes[i],
dilation=self.conv_dilations[i],
#pad=False,
)
self.receptive_field += tdnn_layer.receptive_field
self._conv_layers.append(tdnn_layer)
layer_in_dim += conv_dim
# self.dense_in = nn.Linear(self.input_dim, gru_dim)
self.dense_in = TDNNLayer(
self.input_dim,
gru_dim,
kernel_size=1,
pad=False
)
        # an nn.Linear head on (B, T, C) would apply independently at each time step:
        # self.dense_z = nn.Linear(layer_in_dim, self.output_dim)
        # instead use a TDNN head on (B, C, T) whose kernel spans temporal_dim steps
self.dense_z = TDNNLayer(
layer_in_dim,
self.output_dim,
kernel_size=temporal_dim,
pad=False,
)
# weight initialization for GRU layers
        def init_weights(module: nn.Module) -> None:
            if isinstance(module, nn.GRU):
                for name, p in module.named_parameters():
                    if name.startswith("weight_hh_"):
                        nn.init.orthogonal_(p)
        self.apply(init_weights)
if verbose:
self.log_arch()
def _clamp_rand(self, x: Tensor) -> Tensor:
return torch.clamp(
x + (1.0 / 127.0) * (torch.rand_like(x) - 0.5),
min=-1.0,
max=1.0,
)
def forward(self, **kwargs: Unpack[Kw]) -> tuple[Tensor, ...]:
inputs = kwargs["inputs"]
# embedding shaped (B, C, T)
x = self._clamp_rand(torch.tanh(self.dense_in(inputs)))
# prepare shape (B, T, C) -- for GRU
x = x.transpose(-2, -1)
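        # run the GRU/conv stack, concatenating each layer's output onto the running
        # features (dense skip connections, mirroring layer_in_dim growth in __init__)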
for gru, conv in zip(self._gru_layers, self._conv_layers, strict=True):
xg = self._clamp_rand(gru(x)[0])
x = torch.cat([x, xg], -1)
xc = self._clamp_rand(conv(x.transpose(-2, -1)))
xc = xc.transpose(-2, -1)
x = torch.cat([x, xc], -1)
        # back to (B, C, T) for the final TDNN head
        x = x.transpose(-2, -1)
        z = self.dense_z(x)
return (z,)
def loss(self, **kwargs: Unpack[Kw]) -> Generator[Tensor]:
predictions = self(**kwargs)[0]
labels = kwargs["labels"]
# squeeze last dim; we've mapped T -> 1
predictions = predictions.squeeze(-1)
yield F.mse_loss(predictions, labels, reduction="mean")
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
with torch.no_grad():
loss = next(self.loss(**kwargs)).item()
predictions = self(**kwargs)[0].squeeze(-1)
labels = kwargs["labels"]
mae = F.l1_loss(predictions, labels).item()
return {
"mse": loss,
"mae": mae,
"grad_norm": get_grad_norm(self)
}
def optimizers(
self,
**kwargs: Unpack[OptimizerKwargs],
) -> tuple[Optimizer, ...]:
"""
"""
default_kwargs: Unpack[OptimizerKwargs] = {
"lr": 1e-3,
"eps": 1e-8,
}
opt_kwargs = {**default_kwargs, **kwargs}
optimizer = torch.optim.AdamW(
self.parameters(),
**opt_kwargs,
)
return (optimizer,)
def epoch_step(self) -> None:
return None
def epoch_write(
self,
writer: SummaryWriter,
step: int | None = None,
group: str | None = None,
**kwargs: Unpack[Kw],
) -> None:
return None
def log_arch(self) -> None:
super().log_arch()
logger.info(f"| > {self.input_dim=}")
logger.info(f"| > {self.gru_dim=}")
logger.info(f"| > {self.conv_dim=}")
logger.info(f"| > {self.num_layers=}")
logger.info(f"| > {self.conv_kernel_sizes=}")
logger.info(f"| > {self.conv_dilations=}")
logger.info(f"| > {self.receptive_field=}")
logger.info(f"| > {self.output_dim=}")