monobiome/examples/class.py

69 lines
2.5 KiB
Python

class WLBPosteriorEstimator(PosteriorEstimatorTrainer):
"""
Weighted likelihood bootstrap (WLB) estimator.
Trains models to approximate draws from the *weight posterior* (under
Jeffrey's prior).
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert not self.use_non_atomic_loss
def get_dataloaders(
self,
starting_round: int = 0,
training_batch_size: int = 200,
validation_fraction: float = 0.1,
resume_training: bool = False,
dataloader_kwargs: dict | None = None,
) -> tuple[data.DataLoader, data.DataLoader]:
"""
Add logic for generating session-specific WLB weights.
This is probably the easiest place to stick some fixed weights on a
point-wise basis for a given training run, and we load them later where
we need them in the ``train()`` loop.
"""
theta, x, prior_masks = self.get_simulations(starting_round)
# generate session specific WLB weights to attach point-wise
N = theta.shape[0]
wlb_z = Exponential(1.0).sample((N,))
wlb_w = (wlb_z / wlb_z.sum()) * N
dataset = data.TensorDataset(theta, x, prior_masks, wlb_w)
num_examples = theta.size(0)
num_training_examples = int((1 - validation_fraction) * num_examples)
num_validation_examples = num_examples - num_training_examples
if not resume_training:
permuted_indices = torch.randperm(num_examples)
self.train_indices, self.val_indices = (
permuted_indices[:num_training_examples],
permuted_indices[num_training_examples:],
)
train_loader_kwargs = {
"batch_size": min(training_batch_size, num_training_examples),
"drop_last": True,
"sampler": SubsetRandomSampler(self.train_indices.tolist()),
}
val_loader_kwargs = {
"batch_size": min(training_batch_size, num_validation_examples),
"shuffle": False,
"drop_last": True,
"sampler": SubsetRandomSampler(self.val_indices.tolist()),
}
if dataloader_kwargs is not None:
train_loader_kwargs = dict(train_loader_kwargs, **dataloader_kwargs)
val_loader_kwargs = dict(val_loader_kwargs, **dataloader_kwargs)
train_loader = data.DataLoader(dataset, **train_loader_kwargs)
val_loader = data.DataLoader(dataset, **val_loader_kwargs)
return train_loader, val_loader