consolidate Trainer object, synchronize event logging
This commit is contained in:
@@ -164,7 +164,7 @@ class Estimator[Kw: EstimatorKwargs](nn.Module):
|
|||||||
self,
|
self,
|
||||||
writer: SummaryWriter,
|
writer: SummaryWriter,
|
||||||
step: int | None = None,
|
step: int | None = None,
|
||||||
val: bool = False,
|
group: str | None = None,
|
||||||
**kwargs: Unpack[Kw],
|
**kwargs: Unpack[Kw],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ class MLPKwargs(EstimatorKwargs):
|
|||||||
labels: NotRequired[Tensor]
|
labels: NotRequired[Tensor]
|
||||||
|
|
||||||
|
|
||||||
class MLP[K: MLPKwargs](Estimator[K]):
|
class MLP[Kw: MLPKwargs](Estimator[Kw]):
|
||||||
"""
|
"""
|
||||||
Base MLP architecture.
|
Base MLP architecture.
|
||||||
"""
|
"""
|
||||||
@@ -82,19 +82,19 @@ class MLP[K: MLPKwargs](Estimator[K]):
|
|||||||
max=1.0,
|
max=1.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]:
|
def forward(self, **kwargs: Unpack[Kw]) -> tuple[Tensor, ...]:
|
||||||
inputs = kwargs["inputs"]
|
inputs = kwargs["inputs"]
|
||||||
x = self._net(inputs)
|
x = self._net(inputs)
|
||||||
|
|
||||||
return (x,)
|
return (x,)
|
||||||
|
|
||||||
def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]:
|
def loss(self, **kwargs: Unpack[Kw]) -> Generator[Tensor]:
|
||||||
predictions = self(**kwargs)[0]
|
predictions = self(**kwargs)[0]
|
||||||
labels = kwargs["labels"]
|
labels = kwargs["labels"]
|
||||||
|
|
||||||
yield F.mse_loss(predictions, labels)
|
yield F.mse_loss(predictions, labels)
|
||||||
|
|
||||||
def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]:
|
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
loss = next(self.loss(**kwargs)).item()
|
loss = next(self.loss(**kwargs)).item()
|
||||||
|
|
||||||
@@ -135,8 +135,8 @@ class MLP[K: MLPKwargs](Estimator[K]):
|
|||||||
self,
|
self,
|
||||||
writer: SummaryWriter,
|
writer: SummaryWriter,
|
||||||
step: int | None = None,
|
step: int | None = None,
|
||||||
val: bool = False,
|
group: str | None = None,
|
||||||
**kwargs: Unpack[K],
|
**kwargs: Unpack[Kw],
|
||||||
) -> None:
|
) -> None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ class RNNKwargs(EstimatorKwargs):
|
|||||||
labels: NotRequired[Tensor]
|
labels: NotRequired[Tensor]
|
||||||
|
|
||||||
|
|
||||||
class LSTM[K: RNNKwargs](Estimator[K]):
|
class LSTM[Kw: RNNKwargs](Estimator[Kw]):
|
||||||
"""
|
"""
|
||||||
Base RNN architecture.
|
Base RNN architecture.
|
||||||
"""
|
"""
|
||||||
@@ -85,7 +85,7 @@ class LSTM[K: RNNKwargs](Estimator[K]):
|
|||||||
max=1.0,
|
max=1.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]:
|
def forward(self, **kwargs: Unpack[Kw]) -> tuple[Tensor, ...]:
|
||||||
inputs = kwargs["inputs"]
|
inputs = kwargs["inputs"]
|
||||||
|
|
||||||
# data shaped (B, C, T); map to (B, T, C)
|
# data shaped (B, C, T); map to (B, T, C)
|
||||||
@@ -97,13 +97,13 @@ class LSTM[K: RNNKwargs](Estimator[K]):
|
|||||||
|
|
||||||
return z[:, -1, :], hidden
|
return z[:, -1, :], hidden
|
||||||
|
|
||||||
def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]:
|
def loss(self, **kwargs: Unpack[Kw]) -> Generator[Tensor]:
|
||||||
predictions = self(**kwargs)[0]
|
predictions = self(**kwargs)[0]
|
||||||
labels = kwargs["labels"]
|
labels = kwargs["labels"]
|
||||||
|
|
||||||
yield F.mse_loss(predictions, labels)
|
yield F.mse_loss(predictions, labels)
|
||||||
|
|
||||||
def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]:
|
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
loss = next(self.loss(**kwargs)).item()
|
loss = next(self.loss(**kwargs)).item()
|
||||||
|
|
||||||
@@ -145,8 +145,8 @@ class LSTM[K: RNNKwargs](Estimator[K]):
|
|||||||
self,
|
self,
|
||||||
writer: SummaryWriter,
|
writer: SummaryWriter,
|
||||||
step: int | None = None,
|
step: int | None = None,
|
||||||
val: bool = False,
|
group: str | None = None,
|
||||||
**kwargs: Unpack[K],
|
**kwargs: Unpack[Kw],
|
||||||
) -> None:
|
) -> None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -165,7 +165,7 @@ class MultiheadLSTMKwargs(EstimatorKwargs):
|
|||||||
auxiliary: NotRequired[Tensor]
|
auxiliary: NotRequired[Tensor]
|
||||||
|
|
||||||
|
|
||||||
class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]):
|
class MultiheadLSTM[Kw: MultiheadLSTMKwargs](Estimator[Kw]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
input_dim: int,
|
input_dim: int,
|
||||||
@@ -223,7 +223,7 @@ class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]):
|
|||||||
max=1.0,
|
max=1.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]:
|
def forward(self, **kwargs: Unpack[Kw]) -> tuple[Tensor, ...]:
|
||||||
inputs = kwargs["inputs"]
|
inputs = kwargs["inputs"]
|
||||||
|
|
||||||
# data shaped (B, C, T); map to (B, T, C)
|
# data shaped (B, C, T); map to (B, T, C)
|
||||||
@@ -237,7 +237,7 @@ class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]):
|
|||||||
|
|
||||||
return z[:, -1, :], zs[:, -1, :]
|
return z[:, -1, :], zs[:, -1, :]
|
||||||
|
|
||||||
def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]:
|
def loss(self, **kwargs: Unpack[Kw]) -> Generator[Tensor]:
|
||||||
pred, pred_aux = self(**kwargs)
|
pred, pred_aux = self(**kwargs)
|
||||||
labels = kwargs["labels"]
|
labels = kwargs["labels"]
|
||||||
aux_labels = kwargs.get("auxiliary")
|
aux_labels = kwargs.get("auxiliary")
|
||||||
@@ -247,7 +247,7 @@ class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]):
|
|||||||
else:
|
else:
|
||||||
yield F.mse_loss(pred, labels) + F.mse_loss(pred_aux, aux_labels)
|
yield F.mse_loss(pred, labels) + F.mse_loss(pred_aux, aux_labels)
|
||||||
|
|
||||||
def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]:
|
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
|
||||||
return {
|
return {
|
||||||
"grad_norm": get_grad_norm(self)
|
"grad_norm": get_grad_norm(self)
|
||||||
}
|
}
|
||||||
@@ -279,8 +279,8 @@ class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]):
|
|||||||
self,
|
self,
|
||||||
writer: SummaryWriter,
|
writer: SummaryWriter,
|
||||||
step: int | None = None,
|
step: int | None = None,
|
||||||
val: bool = False,
|
group: str | None = None,
|
||||||
**kwargs: Unpack[K],
|
**kwargs: Unpack[Kw],
|
||||||
) -> None:
|
) -> None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -293,7 +293,7 @@ class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]):
|
|||||||
logger.info(f"| > {self.output_dim=}")
|
logger.info(f"| > {self.output_dim=}")
|
||||||
|
|
||||||
|
|
||||||
class ConvGRU[K: RNNKwargs](Estimator[K]):
|
class ConvGRU[Kw: RNNKwargs](Estimator[Kw]):
|
||||||
"""
|
"""
|
||||||
Base recurrent convolutional architecture.
|
Base recurrent convolutional architecture.
|
||||||
|
|
||||||
@@ -404,7 +404,7 @@ class ConvGRU[K: RNNKwargs](Estimator[K]):
|
|||||||
max=1.0,
|
max=1.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]:
|
def forward(self, **kwargs: Unpack[Kw]) -> tuple[Tensor, ...]:
|
||||||
inputs = kwargs["inputs"]
|
inputs = kwargs["inputs"]
|
||||||
|
|
||||||
# embedding shaped (B, C, T)
|
# embedding shaped (B, C, T)
|
||||||
@@ -430,7 +430,7 @@ class ConvGRU[K: RNNKwargs](Estimator[K]):
|
|||||||
|
|
||||||
return (z,)
|
return (z,)
|
||||||
|
|
||||||
def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]:
|
def loss(self, **kwargs: Unpack[Kw]) -> Generator[Tensor]:
|
||||||
predictions = self(**kwargs)[0]
|
predictions = self(**kwargs)[0]
|
||||||
labels = kwargs["labels"]
|
labels = kwargs["labels"]
|
||||||
|
|
||||||
@@ -439,7 +439,7 @@ class ConvGRU[K: RNNKwargs](Estimator[K]):
|
|||||||
|
|
||||||
yield F.mse_loss(predictions, labels, reduction="mean")
|
yield F.mse_loss(predictions, labels, reduction="mean")
|
||||||
|
|
||||||
def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]:
|
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
loss = next(self.loss(**kwargs)).item()
|
loss = next(self.loss(**kwargs)).item()
|
||||||
|
|
||||||
@@ -480,8 +480,8 @@ class ConvGRU[K: RNNKwargs](Estimator[K]):
|
|||||||
self,
|
self,
|
||||||
writer: SummaryWriter,
|
writer: SummaryWriter,
|
||||||
step: int | None = None,
|
step: int | None = None,
|
||||||
val: bool = False,
|
group: str | None = None,
|
||||||
**kwargs: Unpack[K],
|
**kwargs: Unpack[Kw],
|
||||||
) -> None:
|
) -> None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -105,7 +105,8 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
self._epoch: int = 1
|
self._epoch: int = 1
|
||||||
self._event_log = defaultdict(lambda: defaultdict(dict))
|
self._summary = defaultdict(lambda: defaultdict(dict))
|
||||||
|
self._event_log = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
|
||||||
|
|
||||||
self._val_loss = float("inf")
|
self._val_loss = float("inf")
|
||||||
self._best_val_loss = float("inf")
|
self._best_val_loss = float("inf")
|
||||||
@@ -117,7 +118,6 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
train_loader: DataLoader,
|
train_loader: DataLoader,
|
||||||
batch_estimator_map: Callable[[I, Self], K],
|
batch_estimator_map: Callable[[I, Self], K],
|
||||||
optimizers: tuple[Optimizer, ...],
|
optimizers: tuple[Optimizer, ...],
|
||||||
writer: SummaryWriter,
|
|
||||||
max_grad_norm: float | None = None,
|
max_grad_norm: float | None = None,
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
"""
|
"""
|
||||||
@@ -154,15 +154,8 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
with tqdm(train_loader, unit="batch") as batches:
|
with tqdm(train_loader, unit="batch") as batches:
|
||||||
for i, batch_data in enumerate(batches):
|
for i, batch_data in enumerate(batches):
|
||||||
est_kwargs = batch_estimator_map(batch_data, self)
|
est_kwargs = batch_estimator_map(batch_data, self)
|
||||||
|
|
||||||
# one-time logging
|
|
||||||
if self._step == 0:
|
|
||||||
writer.add_graph(
|
|
||||||
ModelWrapper(self.estimator),
|
|
||||||
est_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
losses = self.estimator.loss(**est_kwargs)
|
losses = self.estimator.loss(**est_kwargs)
|
||||||
|
|
||||||
for o_idx, (loss, optimizer) in enumerate(
|
for o_idx, (loss, optimizer) in enumerate(
|
||||||
zip(losses, optimizers, strict=True)
|
zip(losses, optimizers, strict=True)
|
||||||
):
|
):
|
||||||
@@ -186,9 +179,6 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
|
loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
|
||||||
batches.set_postfix(loss=f"{loss_avg:8.2f}")
|
batches.set_postfix(loss=f"{loss_avg:8.2f}")
|
||||||
|
|
||||||
# grab the assured `inputs` key to measure num samples
|
|
||||||
self._step += len(est_kwargs["inputs"])
|
|
||||||
|
|
||||||
# step estimator hyperparam schedules
|
# step estimator hyperparam schedules
|
||||||
self.estimator.epoch_step()
|
self.estimator.epoch_step()
|
||||||
|
|
||||||
@@ -199,7 +189,6 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
loader: DataLoader,
|
loader: DataLoader,
|
||||||
batch_estimator_map: Callable[[I, Self], K],
|
batch_estimator_map: Callable[[I, Self], K],
|
||||||
loader_label: str,
|
loader_label: str,
|
||||||
writer: SummaryWriter,
|
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
"""
|
"""
|
||||||
Perform and record validation scores for a single epoch.
|
Perform and record validation scores for a single epoch.
|
||||||
@@ -230,8 +219,22 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
with tqdm(loader, unit="batch") as batches:
|
with tqdm(loader, unit="batch") as batches:
|
||||||
for i, batch_data in enumerate(batches):
|
for i, batch_data in enumerate(batches):
|
||||||
est_kwargs = batch_estimator_map(batch_data, self)
|
est_kwargs = batch_estimator_map(batch_data, self)
|
||||||
|
|
||||||
losses = self.estimator.loss(**est_kwargs)
|
losses = self.estimator.loss(**est_kwargs)
|
||||||
|
|
||||||
|
# one-time logging
|
||||||
|
if self._epoch == 0:
|
||||||
|
self._writer.add_graph(
|
||||||
|
ModelWrapper(self.estimator), est_kwargs
|
||||||
|
)
|
||||||
|
# once-per-epoch logging
|
||||||
|
if i == 0:
|
||||||
|
self.estimator.epoch_write(
|
||||||
|
self._writer,
|
||||||
|
step=self._epoch,
|
||||||
|
group=loader_label,
|
||||||
|
**est_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
loss_items = []
|
loss_items = []
|
||||||
for o_idx, loss in enumerate(losses):
|
for o_idx, loss in enumerate(losses):
|
||||||
if len(loss_sums) <= o_idx:
|
if len(loss_sums) <= o_idx:
|
||||||
@@ -247,17 +250,12 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
|
|
||||||
# log individual loss terms after each batch
|
# log individual loss terms after each batch
|
||||||
for o_idx, loss_item in enumerate(loss_items):
|
for o_idx, loss_item in enumerate(loss_items):
|
||||||
self._event_log[run_prefix][loader][name][step] = value
|
self._log_event(loader_label, f"loss_{o_idx}", loss_item)
|
||||||
self._add_summary_item(
|
|
||||||
loader_label, f"loss_{o_idx}", loss_item
|
|
||||||
)
|
|
||||||
|
|
||||||
# log metrics for batch
|
# log metrics for batch
|
||||||
estimator_metrics = self.estimator.metrics(**est_kwargs)
|
estimator_metrics = self.estimator.metrics(**est_kwargs)
|
||||||
for metric_name, metric_value in estimator_metrics.items():
|
for metric_name, metric_value in estimator_metrics.items():
|
||||||
self._add_summary_item(
|
self._log_event(loader_label, metric_name, metric_value)
|
||||||
f"{loader_label}_{metric_name}", metric_value
|
|
||||||
)
|
|
||||||
|
|
||||||
return loss_sums
|
return loss_sums
|
||||||
|
|
||||||
@@ -265,8 +263,7 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
self,
|
self,
|
||||||
loaders: list[DataLoader],
|
loaders: list[DataLoader],
|
||||||
batch_estimator_map: Callable[[I, Self], K],
|
batch_estimator_map: Callable[[I, Self], K],
|
||||||
labels: list[str],
|
loader_labels: list[str],
|
||||||
writer: SummaryWriter,
|
|
||||||
) -> dict[str, list[float]]:
|
) -> dict[str, list[float]]:
|
||||||
"""
|
"""
|
||||||
Evaluate estimator over each provided dataloader.
|
Evaluate estimator over each provided dataloader.
|
||||||
@@ -284,8 +281,8 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
return {
|
return {
|
||||||
label: self._eval_epoch(loader, label, batch_estimator_map, writer)
|
label: self._eval_epoch(loader, batch_estimator_map, label)
|
||||||
for loader, label in zip(loaders, labels, strict=True)
|
for loader, label in zip(loaders, loader_labels, strict=True)
|
||||||
}
|
}
|
||||||
|
|
||||||
def train[B](
|
def train[B](
|
||||||
@@ -307,7 +304,7 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
summarize_every: int = 1,
|
summarize_every: int = 1,
|
||||||
chkpt_every: int = 1,
|
chkpt_every: int = 1,
|
||||||
resume_latest: bool = False,
|
resume_latest: bool = False,
|
||||||
run_prefix: str | None = None,
|
session_name: str | None = None,
|
||||||
summary_writer: SummaryWriter | None = None,
|
summary_writer: SummaryWriter | None = None,
|
||||||
aux_loaders: list[DataLoader] | None = None,
|
aux_loaders: list[DataLoader] | None = None,
|
||||||
aux_loader_labels: list[str] | None = None,
|
aux_loader_labels: list[str] | None = None,
|
||||||
@@ -362,10 +359,42 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
customized) doesn't consistently yield a known type shape, however,
|
customized) doesn't consistently yield a known type shape, however,
|
||||||
so it's not appropriate to use ``I`` as the callable param type.
|
so it's not appropriate to use ``I`` as the callable param type.
|
||||||
|
|
||||||
.. todo::
|
.. admonition:: On session management
|
||||||
|
|
||||||
- Align eval stage of both train and val; currently train is from
|
This method works around an implicit notion of training sessions.
|
||||||
before updates, val is from after at a given epoch
|
Estimators are set during instantiation and effectively coupled
|
||||||
|
with ``Trainer`` instances, but datasets can be supplied
|
||||||
|
dynamically here. One can, for instance, run under one condition
|
||||||
|
(specific dataset, number of epochs, etc), then resume later under
|
||||||
|
another. Relevant details persist across calls: the estimator is
|
||||||
|
still attached, best val scores stowed, current epoch tracked. By
|
||||||
|
default, new ``session_names`` are always generated, but you can
|
||||||
|
write to the same TB location if you using the same
|
||||||
|
``session_name`` across calls; that's about as close to a direct
|
||||||
|
training resume as you could want.
|
||||||
|
|
||||||
|
If restarting training on new datasets, including short
|
||||||
|
fine-tuning on training-plus-validation data, it's often sensible
|
||||||
|
to call ``.reset()`` between ``.train()`` calls. While the same
|
||||||
|
estimator will be used, tracked variables will be wiped; subsequent
|
||||||
|
model updates take place under a fresh epoch, no val losses, and be
|
||||||
|
logged under a separate TB session. This is the general approach to
|
||||||
|
"piecemeal" training, i.e., incremental model updates under varying
|
||||||
|
conditions (often just data changes).
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
Validation convergence when there are multiple losses may be
|
||||||
|
ambiguous. These are cases where certain parameter sets are
|
||||||
|
optimized independently; the sum over these losses may not reflect
|
||||||
|
expected or consistent behavior. For instance, we may record a low
|
||||||
|
cumulative loss early with a small 1st loss and moderate 2nd loss,
|
||||||
|
while later encountering a moderate 1st lost and small 2nd loss. We
|
||||||
|
might prefer the latter case, while ``_converged()`` will stick to
|
||||||
|
the former -- we need to consider possible weighting across losses,
|
||||||
|
or storing possibly several best models (e.g., for each loss, the
|
||||||
|
model that scores best, plus the one scoring best cumulatively,
|
||||||
|
etc).
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
dataset: dataset to train the estimator
|
dataset: dataset to train the estimator
|
||||||
@@ -402,16 +431,14 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
logger.info(f"| > with device: {self.device}")
|
logger.info(f"| > with device: {self.device}")
|
||||||
logger.info(f"| > core count: {os.cpu_count()}")
|
logger.info(f"| > core count: {os.cpu_count()}")
|
||||||
|
|
||||||
writer: SummaryWriter
|
self._session_name = session_name or str(int(time.time()))
|
||||||
run_prefix = run_prefix or str(int(time.time()))
|
tblog_path = Path(self.tblog_dir, self._session_name)
|
||||||
if summary_writer is None:
|
self._writer = summary_writer or SummaryWriter(f"{tblog_path}")
|
||||||
writer = SummaryWriter(f"{Path(self.tblog_dir, run_prefix)}")
|
|
||||||
else:
|
|
||||||
writer = summary_writer
|
|
||||||
|
|
||||||
aux_loaders = aux_loaders or []
|
aux_loaders = aux_loaders or []
|
||||||
aux_loader_labels = aux_loader_labels or []
|
aux_loader_labels = aux_loader_labels or []
|
||||||
|
|
||||||
|
optimizers = self.estimator.optimizers(lr=lr, eps=eps)
|
||||||
train_loader, val_loader = self.get_dataloaders(
|
train_loader, val_loader = self.get_dataloaders(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size,
|
batch_size,
|
||||||
@@ -426,12 +453,8 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
loader_labels = ["train", "val", *aux_loader_labels]
|
loader_labels = ["train", "val", *aux_loader_labels]
|
||||||
|
|
||||||
# evaluate model on dataloaders once before training starts
|
# evaluate model on dataloaders once before training starts
|
||||||
self._eval_loaders(loaders, batch_estimator_map, # loader_labels, writer)
|
self._eval_loaders(loaders, batch_estimator_map, loader_labels)
|
||||||
|
|
||||||
optimizers = self.estimator.optimizers(lr=lr, eps=eps)
|
|
||||||
|
|
||||||
self._step = 0
|
|
||||||
self._epoch = 1 # start from 1 for logging convenience
|
|
||||||
while self._epoch <= max_epochs and not self._converged(
|
while self._epoch <= max_epochs and not self._converged(
|
||||||
self._epoch, stop_after_epochs
|
self._epoch, stop_after_epochs
|
||||||
):
|
):
|
||||||
@@ -446,41 +469,22 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
batch_estimator_map,
|
batch_estimator_map,
|
||||||
optimizers,
|
optimizers,
|
||||||
max_grad_norm,
|
max_grad_norm,
|
||||||
writer,
|
|
||||||
)
|
)
|
||||||
self._add_summary_item(
|
epoch_end_time = time.time() - epoch_start_time
|
||||||
"epoch_time_sec", time.time() - epoch_start_time
|
self._log_event("train", "epoch_duration", epoch_end_time)
|
||||||
)
|
|
||||||
# writer steps are measured in number of samples; log epoch as an
|
|
||||||
# alternative step resolution for interpreting event timing
|
|
||||||
self._add_summary_item("step", float(self._step))
|
|
||||||
|
|
||||||
|
|
||||||
# once-per-epoch logging
|
|
||||||
self.estimator.epoch_write(
|
|
||||||
writer,
|
|
||||||
step=self._step,
|
|
||||||
val=True,
|
|
||||||
**est_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
loss_sum_map = self._eval_loaders(
|
loss_sum_map = self._eval_loaders(
|
||||||
loaders, loader_labels, batch_estimator_map, writer
|
loaders,
|
||||||
|
batch_estimator_map,
|
||||||
|
loader_labels,
|
||||||
)
|
)
|
||||||
# convergence of multiple losses may be ambiguous
|
|
||||||
val_loss_sums = loss_sum_map["val"]
|
val_loss_sums = loss_sum_map["val"]
|
||||||
self._val_loss = sum(val_loss_sums) / len(val_loader)
|
self._val_loss = sum(val_loss_sums) / len(val_loader)
|
||||||
|
|
||||||
if self._epoch % summarize_every == 0:
|
if self._epoch % summarize_every == 0:
|
||||||
self._summarize(writer, self._epoch, run_prefix)
|
self._summarize()
|
||||||
|
|
||||||
if self._epoch % chkpt_every == 0:
|
if self._epoch % chkpt_every == 0:
|
||||||
self.save_model(self.chkpt_dir, self._epoch, run_prefix)
|
self.save_model()
|
||||||
|
|
||||||
self._epoch += 1
|
self._epoch += 1
|
||||||
|
|
||||||
return self.estimator
|
return self.estimator
|
||||||
@@ -514,6 +518,10 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
) -> tuple[DataLoader, DataLoader]:
|
) -> tuple[DataLoader, DataLoader]:
|
||||||
"""
|
"""
|
||||||
Create training and validation dataloaders for the provided dataset.
|
Create training and validation dataloaders for the provided dataset.
|
||||||
|
|
||||||
|
.. todo::
|
||||||
|
|
||||||
|
Decide on policy for empty val dataloaders
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if dataset_split_kwargs is None:
|
if dataset_split_kwargs is None:
|
||||||
@@ -571,12 +579,7 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
|
|
||||||
return train_loader, val_loader
|
return train_loader, val_loader
|
||||||
|
|
||||||
def _summarize(
|
def _summarize(self) -> None:
|
||||||
self,
|
|
||||||
writer: SummaryWriter,
|
|
||||||
epoch: int,
|
|
||||||
run_prefix: str
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Flush the training summary to the TensorBoard summary writer.
|
Flush the training summary to the TensorBoard summary writer.
|
||||||
|
|
||||||
@@ -588,21 +591,24 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
averages from epochs 1-10.
|
averages from epochs 1-10.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
summary_values = defaultdict(list)
|
epoch_values = defaultdict(lambda: defaultdict(list))
|
||||||
for name, records in self._summary.items():
|
for group, records in self._summary.items():
|
||||||
for value, step in records:
|
for name, steps in records.items():
|
||||||
writer.add_scalar(name, value, step)
|
for step, value in steps.items():
|
||||||
# writer.add_scalar(name, value, self._epoch)
|
self._writer.add_scalar(f"{group}-{name}", value, step)
|
||||||
summary_values[name].append(value)
|
if step == self._epoch:
|
||||||
self._event_log[run_prefix][name][step] = value
|
epoch_values[group][name].append(value)
|
||||||
|
|
||||||
print(f"==== Epoch [{epoch}] summary ====")
|
print(f"==== Epoch [{self._epoch}] summary ====")
|
||||||
for name, values in summary_values.items():
|
for group, records in epoch_values.items():
|
||||||
|
for name, values in records.items():
|
||||||
mean_value = torch.tensor(values).mean().item()
|
mean_value = torch.tensor(values).mean().item()
|
||||||
print(f"> ({len(values)}) {name} :: {mean_value:.2f}")
|
print(
|
||||||
|
f"> ({len(values)}) [{group}] {name} :: {mean_value:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
writer.flush()
|
self._writer.flush()
|
||||||
self._summary = defaultdict(list)
|
self._summary = defaultdict(lambda: defaultdict(dict))
|
||||||
|
|
||||||
def _get_optimizer_parameters(
|
def _get_optimizer_parameters(
|
||||||
self,
|
self,
|
||||||
@@ -615,11 +621,9 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
if param.grad is not None
|
if param.grad is not None
|
||||||
]
|
]
|
||||||
|
|
||||||
def _add_summary_item(self, group: str, name: str, value: float) -> None:
|
def _log_event(self, group: str, name: str, value: float) -> None:
|
||||||
# self._summary[name].append((value, self._step))
|
self._summary[group][name][self._epoch] = value
|
||||||
self._summary[name].append((value, self._epoch))
|
self._event_log[self._session_name][group][name][self._epoch] = value
|
||||||
|
|
||||||
self._event_log[run_prefix][name][step] = value
|
|
||||||
|
|
||||||
def get_batch_outputs[B](
|
def get_batch_outputs[B](
|
||||||
self,
|
self,
|
||||||
@@ -646,12 +650,7 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
|
|
||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
def save_model(
|
def save_model(self) -> None:
|
||||||
self,
|
|
||||||
chkpt_dir: str | Path,
|
|
||||||
epoch: int,
|
|
||||||
run_prefix: str,
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Save a model checkpoint.
|
Save a model checkpoint.
|
||||||
"""
|
"""
|
||||||
@@ -661,9 +660,9 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
model_buff.seek(0)
|
model_buff.seek(0)
|
||||||
|
|
||||||
model_class = self.estimator.__class__.__name__
|
model_class = self.estimator.__class__.__name__
|
||||||
chkpt_name = f"m_{model_class}-e_{epoch}.pth"
|
chkpt_name = f"m_{model_class}-e_{self._epoch}.pth"
|
||||||
|
|
||||||
chkpt_dir = Path(chkpt_dir, run_prefix)
|
chkpt_dir = Path(self.chkpt_dir, self._session_name)
|
||||||
chkpt_path = Path(chkpt_dir, chkpt_name)
|
chkpt_path = Path(chkpt_dir, chkpt_name)
|
||||||
|
|
||||||
chkpt_dir.mkdir(parents=True, exist_ok=True)
|
chkpt_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user