make chroma refinements for v1.3.0
examples/class.py (new file, 68 lines)
@@ -0,0 +1,68 @@
import torch
from torch.distributions import Exponential
from torch.utils import data
from torch.utils.data import SubsetRandomSampler

# NOTE: PosteriorEstimatorTrainer is assumed to be provided by the surrounding
# project; its import is not shown in this example file.


class WLBPosteriorEstimator(PosteriorEstimatorTrainer):
    """
    Weighted likelihood bootstrap (WLB) estimator.

    Trains models to approximate draws from the *weight posterior* (under
    the Jeffreys prior).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        assert not self.use_non_atomic_loss

    def get_dataloaders(
        self,
        starting_round: int = 0,
        training_batch_size: int = 200,
        validation_fraction: float = 0.1,
        resume_training: bool = False,
        dataloader_kwargs: dict | None = None,
    ) -> tuple[data.DataLoader, data.DataLoader]:
        """
        Add logic for generating session-specific WLB weights.

        This is probably the easiest place to attach fixed weights on a
        point-wise basis for a given training run, and we load them later
        where we need them in the ``train()`` loop.
        """
        theta, x, prior_masks = self.get_simulations(starting_round)

        # generate session-specific WLB weights to attach point-wise
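        # (Exponential(1) draws normalized to sum to one are a flat Dirichlet
        # sample; scaling by N then gives per-example weights with mean 1,
        # the standard WLB construction)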
        N = theta.shape[0]
        wlb_z = Exponential(1.0).sample((N,))
        wlb_w = (wlb_z / wlb_z.sum()) * N

        dataset = data.TensorDataset(theta, x, prior_masks, wlb_w)
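        # the weights ride along as a fourth tensor, so the ``train()`` loop
        # can pull a matching weight for every example in each batch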
        num_examples = theta.size(0)
        num_training_examples = int((1 - validation_fraction) * num_examples)
        num_validation_examples = num_examples - num_training_examples

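        # a fresh train/validation split is drawn only for a new run; when
        # resume_training is True the indices stored on self are reused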
        if not resume_training:
            permuted_indices = torch.randperm(num_examples)
            self.train_indices, self.val_indices = (
                permuted_indices[:num_training_examples],
                permuted_indices[num_training_examples:],
            )

        train_loader_kwargs = {
            "batch_size": min(training_batch_size, num_training_examples),
            "drop_last": True,
            "sampler": SubsetRandomSampler(self.train_indices.tolist()),
        }
        val_loader_kwargs = {
            "batch_size": min(training_batch_size, num_validation_examples),
            "shuffle": False,
            "drop_last": True,
            "sampler": SubsetRandomSampler(self.val_indices.tolist()),
        }
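        # user-supplied dataloader kwargs take precedence over the defaults above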
        if dataloader_kwargs is not None:
            train_loader_kwargs = dict(train_loader_kwargs, **dataloader_kwargs)
            val_loader_kwargs = dict(val_loader_kwargs, **dataloader_kwargs)

        train_loader = data.DataLoader(dataset, **train_loader_kwargs)
        val_loader = data.DataLoader(dataset, **val_loader_kwargs)

        return train_loader, val_loader
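
# A minimal usage sketch (illustrative only): it assumes the base trainer
# exposes ``append_simulations()`` and ``train()`` methods and that ``prior``,
# ``theta`` and ``x`` are defined elsewhere; none of that is shown in this file.
#
#     estimator = WLBPosteriorEstimator(prior=prior)
#     estimator.append_simulations(theta, x)
#     density_estimator = estimator.train()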