consolidate Trainer object, synchronize event logging
This commit is contained in:
@@ -164,7 +164,7 @@ class Estimator[Kw: EstimatorKwargs](nn.Module):
|
||||
self,
|
||||
writer: SummaryWriter,
|
||||
step: int | None = None,
|
||||
val: bool = False,
|
||||
group: str | None = None,
|
||||
**kwargs: Unpack[Kw],
|
||||
) -> None:
|
||||
"""
|
||||
|
||||
@@ -21,7 +21,7 @@ class MLPKwargs(EstimatorKwargs):
|
||||
labels: NotRequired[Tensor]
|
||||
|
||||
|
||||
class MLP[K: MLPKwargs](Estimator[K]):
|
||||
class MLP[Kw: MLPKwargs](Estimator[Kw]):
|
||||
"""
|
||||
Base MLP architecture.
|
||||
"""
|
||||
@@ -82,19 +82,19 @@ class MLP[K: MLPKwargs](Estimator[K]):
|
||||
max=1.0,
|
||||
)
|
||||
|
||||
def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]:
|
||||
def forward(self, **kwargs: Unpack[Kw]) -> tuple[Tensor, ...]:
|
||||
inputs = kwargs["inputs"]
|
||||
x = self._net(inputs)
|
||||
|
||||
return (x,)
|
||||
|
||||
def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]:
|
||||
def loss(self, **kwargs: Unpack[Kw]) -> Generator[Tensor]:
|
||||
predictions = self(**kwargs)[0]
|
||||
labels = kwargs["labels"]
|
||||
|
||||
yield F.mse_loss(predictions, labels)
|
||||
|
||||
def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]:
|
||||
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
|
||||
with torch.no_grad():
|
||||
loss = next(self.loss(**kwargs)).item()
|
||||
|
||||
@@ -135,8 +135,8 @@ class MLP[K: MLPKwargs](Estimator[K]):
|
||||
self,
|
||||
writer: SummaryWriter,
|
||||
step: int | None = None,
|
||||
val: bool = False,
|
||||
**kwargs: Unpack[K],
|
||||
group: str | None = None,
|
||||
**kwargs: Unpack[Kw],
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ class RNNKwargs(EstimatorKwargs):
|
||||
labels: NotRequired[Tensor]
|
||||
|
||||
|
||||
class LSTM[K: RNNKwargs](Estimator[K]):
|
||||
class LSTM[Kw: RNNKwargs](Estimator[Kw]):
|
||||
"""
|
||||
Base RNN architecture.
|
||||
"""
|
||||
@@ -85,7 +85,7 @@ class LSTM[K: RNNKwargs](Estimator[K]):
|
||||
max=1.0,
|
||||
)
|
||||
|
||||
def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]:
|
||||
def forward(self, **kwargs: Unpack[Kw]) -> tuple[Tensor, ...]:
|
||||
inputs = kwargs["inputs"]
|
||||
|
||||
# data shaped (B, C, T); map to (B, T, C)
|
||||
@@ -97,13 +97,13 @@ class LSTM[K: RNNKwargs](Estimator[K]):
|
||||
|
||||
return z[:, -1, :], hidden
|
||||
|
||||
def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]:
|
||||
def loss(self, **kwargs: Unpack[Kw]) -> Generator[Tensor]:
|
||||
predictions = self(**kwargs)[0]
|
||||
labels = kwargs["labels"]
|
||||
|
||||
yield F.mse_loss(predictions, labels)
|
||||
|
||||
def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]:
|
||||
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
|
||||
with torch.no_grad():
|
||||
loss = next(self.loss(**kwargs)).item()
|
||||
|
||||
@@ -145,8 +145,8 @@ class LSTM[K: RNNKwargs](Estimator[K]):
|
||||
self,
|
||||
writer: SummaryWriter,
|
||||
step: int | None = None,
|
||||
val: bool = False,
|
||||
**kwargs: Unpack[K],
|
||||
group: str | None = None,
|
||||
**kwargs: Unpack[Kw],
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
@@ -165,7 +165,7 @@ class MultiheadLSTMKwargs(EstimatorKwargs):
|
||||
auxiliary: NotRequired[Tensor]
|
||||
|
||||
|
||||
class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]):
|
||||
class MultiheadLSTM[Kw: MultiheadLSTMKwargs](Estimator[Kw]):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
@@ -223,7 +223,7 @@ class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]):
|
||||
max=1.0,
|
||||
)
|
||||
|
||||
def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]:
|
||||
def forward(self, **kwargs: Unpack[Kw]) -> tuple[Tensor, ...]:
|
||||
inputs = kwargs["inputs"]
|
||||
|
||||
# data shaped (B, C, T); map to (B, T, C)
|
||||
@@ -237,7 +237,7 @@ class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]):
|
||||
|
||||
return z[:, -1, :], zs[:, -1, :]
|
||||
|
||||
def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]:
|
||||
def loss(self, **kwargs: Unpack[Kw]) -> Generator[Tensor]:
|
||||
pred, pred_aux = self(**kwargs)
|
||||
labels = kwargs["labels"]
|
||||
aux_labels = kwargs.get("auxiliary")
|
||||
@@ -247,7 +247,7 @@ class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]):
|
||||
else:
|
||||
yield F.mse_loss(pred, labels) + F.mse_loss(pred_aux, aux_labels)
|
||||
|
||||
def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]:
|
||||
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
|
||||
return {
|
||||
"grad_norm": get_grad_norm(self)
|
||||
}
|
||||
@@ -279,8 +279,8 @@ class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]):
|
||||
self,
|
||||
writer: SummaryWriter,
|
||||
step: int | None = None,
|
||||
val: bool = False,
|
||||
**kwargs: Unpack[K],
|
||||
group: str | None = None,
|
||||
**kwargs: Unpack[Kw],
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
@@ -293,7 +293,7 @@ class MultiheadLSTM[K: MultiheadLSTMKwargs](Estimator[K]):
|
||||
logger.info(f"| > {self.output_dim=}")
|
||||
|
||||
|
||||
class ConvGRU[K: RNNKwargs](Estimator[K]):
|
||||
class ConvGRU[Kw: RNNKwargs](Estimator[Kw]):
|
||||
"""
|
||||
Base recurrent convolutional architecture.
|
||||
|
||||
@@ -404,7 +404,7 @@ class ConvGRU[K: RNNKwargs](Estimator[K]):
|
||||
max=1.0,
|
||||
)
|
||||
|
||||
def forward(self, **kwargs: Unpack[K]) -> tuple[Tensor, ...]:
|
||||
def forward(self, **kwargs: Unpack[Kw]) -> tuple[Tensor, ...]:
|
||||
inputs = kwargs["inputs"]
|
||||
|
||||
# embedding shaped (B, C, T)
|
||||
@@ -430,7 +430,7 @@ class ConvGRU[K: RNNKwargs](Estimator[K]):
|
||||
|
||||
return (z,)
|
||||
|
||||
def loss(self, **kwargs: Unpack[K]) -> Generator[Tensor]:
|
||||
def loss(self, **kwargs: Unpack[Kw]) -> Generator[Tensor]:
|
||||
predictions = self(**kwargs)[0]
|
||||
labels = kwargs["labels"]
|
||||
|
||||
@@ -439,7 +439,7 @@ class ConvGRU[K: RNNKwargs](Estimator[K]):
|
||||
|
||||
yield F.mse_loss(predictions, labels, reduction="mean")
|
||||
|
||||
def metrics(self, **kwargs: Unpack[K]) -> dict[str, float]:
|
||||
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
|
||||
with torch.no_grad():
|
||||
loss = next(self.loss(**kwargs)).item()
|
||||
|
||||
@@ -480,8 +480,8 @@ class ConvGRU[K: RNNKwargs](Estimator[K]):
|
||||
self,
|
||||
writer: SummaryWriter,
|
||||
step: int | None = None,
|
||||
val: bool = False,
|
||||
**kwargs: Unpack[K],
|
||||
group: str | None = None,
|
||||
**kwargs: Unpack[Kw],
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@@ -105,7 +105,8 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
"""
|
||||
|
||||
self._epoch: int = 1
|
||||
self._event_log = defaultdict(lambda: defaultdict(dict))
|
||||
self._summary = defaultdict(lambda: defaultdict(dict))
|
||||
self._event_log = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
|
||||
|
||||
self._val_loss = float("inf")
|
||||
self._best_val_loss = float("inf")
|
||||
@@ -117,7 +118,6 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
train_loader: DataLoader,
|
||||
batch_estimator_map: Callable[[I, Self], K],
|
||||
optimizers: tuple[Optimizer, ...],
|
||||
writer: SummaryWriter,
|
||||
max_grad_norm: float | None = None,
|
||||
) -> list[float]:
|
||||
"""
|
||||
@@ -154,15 +154,8 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
with tqdm(train_loader, unit="batch") as batches:
|
||||
for i, batch_data in enumerate(batches):
|
||||
est_kwargs = batch_estimator_map(batch_data, self)
|
||||
|
||||
# one-time logging
|
||||
if self._step == 0:
|
||||
writer.add_graph(
|
||||
ModelWrapper(self.estimator),
|
||||
est_kwargs
|
||||
)
|
||||
|
||||
losses = self.estimator.loss(**est_kwargs)
|
||||
|
||||
for o_idx, (loss, optimizer) in enumerate(
|
||||
zip(losses, optimizers, strict=True)
|
||||
):
|
||||
@@ -186,9 +179,6 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
|
||||
batches.set_postfix(loss=f"{loss_avg:8.2f}")
|
||||
|
||||
# grab the assured `inputs` key to measure num samples
|
||||
self._step += len(est_kwargs["inputs"])
|
||||
|
||||
# step estimator hyperparam schedules
|
||||
self.estimator.epoch_step()
|
||||
|
||||
@@ -199,7 +189,6 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
loader: DataLoader,
|
||||
batch_estimator_map: Callable[[I, Self], K],
|
||||
loader_label: str,
|
||||
writer: SummaryWriter,
|
||||
) -> list[float]:
|
||||
"""
|
||||
Perform and record validation scores for a single epoch.
|
||||
@@ -230,8 +219,22 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
with tqdm(loader, unit="batch") as batches:
|
||||
for i, batch_data in enumerate(batches):
|
||||
est_kwargs = batch_estimator_map(batch_data, self)
|
||||
|
||||
losses = self.estimator.loss(**est_kwargs)
|
||||
|
||||
# one-time logging
|
||||
if self._epoch == 0:
|
||||
self._writer.add_graph(
|
||||
ModelWrapper(self.estimator), est_kwargs
|
||||
)
|
||||
# once-per-epoch logging
|
||||
if i == 0:
|
||||
self.estimator.epoch_write(
|
||||
self._writer,
|
||||
step=self._epoch,
|
||||
group=loader_label,
|
||||
**est_kwargs
|
||||
)
|
||||
|
||||
loss_items = []
|
||||
for o_idx, loss in enumerate(losses):
|
||||
if len(loss_sums) <= o_idx:
|
||||
@@ -247,17 +250,12 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
|
||||
# log individual loss terms after each batch
|
||||
for o_idx, loss_item in enumerate(loss_items):
|
||||
self._event_log[run_prefix][loader][name][step] = value
|
||||
self._add_summary_item(
|
||||
loader_label, f"loss_{o_idx}", loss_item
|
||||
)
|
||||
self._log_event(loader_label, f"loss_{o_idx}", loss_item)
|
||||
|
||||
# log metrics for batch
|
||||
estimator_metrics = self.estimator.metrics(**est_kwargs)
|
||||
for metric_name, metric_value in estimator_metrics.items():
|
||||
self._add_summary_item(
|
||||
f"{loader_label}_{metric_name}", metric_value
|
||||
)
|
||||
self._log_event(loader_label, metric_name, metric_value)
|
||||
|
||||
return loss_sums
|
||||
|
||||
@@ -265,8 +263,7 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
self,
|
||||
loaders: list[DataLoader],
|
||||
batch_estimator_map: Callable[[I, Self], K],
|
||||
labels: list[str],
|
||||
writer: SummaryWriter,
|
||||
loader_labels: list[str],
|
||||
) -> dict[str, list[float]]:
|
||||
"""
|
||||
Evaluate estimator over each provided dataloader.
|
||||
@@ -284,8 +281,8 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
"""
|
||||
|
||||
return {
|
||||
label: self._eval_epoch(loader, label, batch_estimator_map, writer)
|
||||
for loader, label in zip(loaders, labels, strict=True)
|
||||
label: self._eval_epoch(loader, batch_estimator_map, label)
|
||||
for loader, label in zip(loaders, loader_labels, strict=True)
|
||||
}
|
||||
|
||||
def train[B](
|
||||
@@ -307,7 +304,7 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
summarize_every: int = 1,
|
||||
chkpt_every: int = 1,
|
||||
resume_latest: bool = False,
|
||||
run_prefix: str | None = None,
|
||||
session_name: str | None = None,
|
||||
summary_writer: SummaryWriter | None = None,
|
||||
aux_loaders: list[DataLoader] | None = None,
|
||||
aux_loader_labels: list[str] | None = None,
|
||||
@@ -362,10 +359,42 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
customized) doesn't consistently yield a known type shape, however,
|
||||
so it's not appropriate to use ``I`` as the callable param type.
|
||||
|
||||
.. todo::
|
||||
.. admonition:: On session management
|
||||
|
||||
- Align eval stage of both train and val; currently train is from
|
||||
before updates, val is from after at a given epoch
|
||||
This method works around an implicit notion of training sessions.
|
||||
Estimators are set during instantiation and effectively coupled
|
||||
with ``Trainer`` instances, but datasets can be supplied
|
||||
dynamically here. One can, for instance, run under one condition
|
||||
(specific dataset, number of epochs, etc), then resume later under
|
||||
another. Relevant details persist across calls: the estimator is
|
||||
still attached, best val scores stowed, current epoch tracked. By
|
||||
default, new ``session_names`` are always generated, but you can
|
||||
write to the same TB location if you using the same
|
||||
``session_name`` across calls; that's about as close to a direct
|
||||
training resume as you could want.
|
||||
|
||||
If restarting training on new datasets, including short
|
||||
fine-tuning on training-plus-validation data, it's often sensible
|
||||
to call ``.reset()`` between ``.train()`` calls. While the same
|
||||
estimator will be used, tracked variables will be wiped; subsequent
|
||||
model updates take place under a fresh epoch, no val losses, and be
|
||||
logged under a separate TB session. This is the general approach to
|
||||
"piecemeal" training, i.e., incremental model updates under varying
|
||||
conditions (often just data changes).
|
||||
|
||||
.. warning::
|
||||
|
||||
Validation convergence when there are multiple losses may be
|
||||
ambiguous. These are cases where certain parameter sets are
|
||||
optimized independently; the sum over these losses may not reflect
|
||||
expected or consistent behavior. For instance, we may record a low
|
||||
cumulative loss early with a small 1st loss and moderate 2nd loss,
|
||||
while later encountering a moderate 1st lost and small 2nd loss. We
|
||||
might prefer the latter case, while ``_converged()`` will stick to
|
||||
the former -- we need to consider possible weighting across losses,
|
||||
or storing possibly several best models (e.g., for each loss, the
|
||||
model that scores best, plus the one scoring best cumulatively,
|
||||
etc).
|
||||
|
||||
Parameters:
|
||||
dataset: dataset to train the estimator
|
||||
@@ -402,16 +431,14 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
logger.info(f"| > with device: {self.device}")
|
||||
logger.info(f"| > core count: {os.cpu_count()}")
|
||||
|
||||
writer: SummaryWriter
|
||||
run_prefix = run_prefix or str(int(time.time()))
|
||||
if summary_writer is None:
|
||||
writer = SummaryWriter(f"{Path(self.tblog_dir, run_prefix)}")
|
||||
else:
|
||||
writer = summary_writer
|
||||
self._session_name = session_name or str(int(time.time()))
|
||||
tblog_path = Path(self.tblog_dir, self._session_name)
|
||||
self._writer = summary_writer or SummaryWriter(f"{tblog_path}")
|
||||
|
||||
aux_loaders = aux_loaders or []
|
||||
aux_loader_labels = aux_loader_labels or []
|
||||
|
||||
optimizers = self.estimator.optimizers(lr=lr, eps=eps)
|
||||
train_loader, val_loader = self.get_dataloaders(
|
||||
dataset,
|
||||
batch_size,
|
||||
@@ -426,12 +453,8 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
loader_labels = ["train", "val", *aux_loader_labels]
|
||||
|
||||
# evaluate model on dataloaders once before training starts
|
||||
self._eval_loaders(loaders, batch_estimator_map, # loader_labels, writer)
|
||||
self._eval_loaders(loaders, batch_estimator_map, loader_labels)
|
||||
|
||||
optimizers = self.estimator.optimizers(lr=lr, eps=eps)
|
||||
|
||||
self._step = 0
|
||||
self._epoch = 1 # start from 1 for logging convenience
|
||||
while self._epoch <= max_epochs and not self._converged(
|
||||
self._epoch, stop_after_epochs
|
||||
):
|
||||
@@ -446,41 +469,22 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
batch_estimator_map,
|
||||
optimizers,
|
||||
max_grad_norm,
|
||||
writer,
|
||||
)
|
||||
self._add_summary_item(
|
||||
"epoch_time_sec", time.time() - epoch_start_time
|
||||
)
|
||||
# writer steps are measured in number of samples; log epoch as an
|
||||
# alternative step resolution for interpreting event timing
|
||||
self._add_summary_item("step", float(self._step))
|
||||
|
||||
|
||||
# once-per-epoch logging
|
||||
self.estimator.epoch_write(
|
||||
writer,
|
||||
step=self._step,
|
||||
val=True,
|
||||
**est_kwargs
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
epoch_end_time = time.time() - epoch_start_time
|
||||
self._log_event("train", "epoch_duration", epoch_end_time)
|
||||
|
||||
loss_sum_map = self._eval_loaders(
|
||||
loaders, loader_labels, batch_estimator_map, writer
|
||||
loaders,
|
||||
batch_estimator_map,
|
||||
loader_labels,
|
||||
)
|
||||
# convergence of multiple losses may be ambiguous
|
||||
val_loss_sums = loss_sum_map["val"]
|
||||
self._val_loss = sum(val_loss_sums) / len(val_loader)
|
||||
|
||||
if self._epoch % summarize_every == 0:
|
||||
self._summarize(writer, self._epoch, run_prefix)
|
||||
|
||||
self._summarize()
|
||||
if self._epoch % chkpt_every == 0:
|
||||
self.save_model(self.chkpt_dir, self._epoch, run_prefix)
|
||||
|
||||
self.save_model()
|
||||
self._epoch += 1
|
||||
|
||||
return self.estimator
|
||||
@@ -514,6 +518,10 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
) -> tuple[DataLoader, DataLoader]:
|
||||
"""
|
||||
Create training and validation dataloaders for the provided dataset.
|
||||
|
||||
.. todo::
|
||||
|
||||
Decide on policy for empty val dataloaders
|
||||
"""
|
||||
|
||||
if dataset_split_kwargs is None:
|
||||
@@ -571,12 +579,7 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
|
||||
return train_loader, val_loader
|
||||
|
||||
def _summarize(
|
||||
self,
|
||||
writer: SummaryWriter,
|
||||
epoch: int,
|
||||
run_prefix: str
|
||||
) -> None:
|
||||
def _summarize(self) -> None:
|
||||
"""
|
||||
Flush the training summary to the TensorBoard summary writer.
|
||||
|
||||
@@ -588,21 +591,24 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
averages from epochs 1-10.
|
||||
"""
|
||||
|
||||
summary_values = defaultdict(list)
|
||||
for name, records in self._summary.items():
|
||||
for value, step in records:
|
||||
writer.add_scalar(name, value, step)
|
||||
# writer.add_scalar(name, value, self._epoch)
|
||||
summary_values[name].append(value)
|
||||
self._event_log[run_prefix][name][step] = value
|
||||
epoch_values = defaultdict(lambda: defaultdict(list))
|
||||
for group, records in self._summary.items():
|
||||
for name, steps in records.items():
|
||||
for step, value in steps.items():
|
||||
self._writer.add_scalar(f"{group}-{name}", value, step)
|
||||
if step == self._epoch:
|
||||
epoch_values[group][name].append(value)
|
||||
|
||||
print(f"==== Epoch [{epoch}] summary ====")
|
||||
for name, values in summary_values.items():
|
||||
mean_value = torch.tensor(values).mean().item()
|
||||
print(f"> ({len(values)}) {name} :: {mean_value:.2f}")
|
||||
print(f"==== Epoch [{self._epoch}] summary ====")
|
||||
for group, records in epoch_values.items():
|
||||
for name, values in records.items():
|
||||
mean_value = torch.tensor(values).mean().item()
|
||||
print(
|
||||
f"> ({len(values)}) [{group}] {name} :: {mean_value:.2f}"
|
||||
)
|
||||
|
||||
writer.flush()
|
||||
self._summary = defaultdict(list)
|
||||
self._writer.flush()
|
||||
self._summary = defaultdict(lambda: defaultdict(dict))
|
||||
|
||||
def _get_optimizer_parameters(
|
||||
self,
|
||||
@@ -615,11 +621,9 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
if param.grad is not None
|
||||
]
|
||||
|
||||
def _add_summary_item(self, group: str, name: str, value: float) -> None:
|
||||
# self._summary[name].append((value, self._step))
|
||||
self._summary[name].append((value, self._epoch))
|
||||
|
||||
self._event_log[run_prefix][name][step] = value
|
||||
def _log_event(self, group: str, name: str, value: float) -> None:
|
||||
self._summary[group][name][self._epoch] = value
|
||||
self._event_log[self._session_name][group][name][self._epoch] = value
|
||||
|
||||
def get_batch_outputs[B](
|
||||
self,
|
||||
@@ -646,12 +650,7 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
|
||||
return metrics
|
||||
|
||||
def save_model(
|
||||
self,
|
||||
chkpt_dir: str | Path,
|
||||
epoch: int,
|
||||
run_prefix: str,
|
||||
) -> None:
|
||||
def save_model(self) -> None:
|
||||
"""
|
||||
Save a model checkpoint.
|
||||
"""
|
||||
@@ -661,9 +660,9 @@ class Trainer[I, K: EstimatorKwargs]:
|
||||
model_buff.seek(0)
|
||||
|
||||
model_class = self.estimator.__class__.__name__
|
||||
chkpt_name = f"m_{model_class}-e_{epoch}.pth"
|
||||
chkpt_name = f"m_{model_class}-e_{self._epoch}.pth"
|
||||
|
||||
chkpt_dir = Path(chkpt_dir, run_prefix)
|
||||
chkpt_dir = Path(self.chkpt_dir, self._session_name)
|
||||
chkpt_path = Path(chkpt_dir, chkpt_name)
|
||||
|
||||
chkpt_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
Reference in New Issue
Block a user