Compare commits
4 Commits
85d176862e
...
0.3.1
| Author | SHA1 | Date | |
|---|---|---|---|
| fdccb4c5eb | |||
| ba0c804d5e | |||
| b59749c8d8 | |||
| a395a08d5c |
11
TODO.md
Normal file
11
TODO.md
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
# Long-term
|
||||||
|
- Implement a dataloader in-house, with a clear, lightweight mechanism for
|
||||||
|
collection-of-structures to structure-of-collections. For multi-proc handling
|
||||||
|
(happens in torch's dataloader, as well as the BatchedDataset for two
|
||||||
|
different purposes), we should rely on (a hopefully more stable) `execlib`.
|
||||||
|
- `Domains` may be externalized (`co3` or `convlib`)
|
||||||
|
- Up next: CLI, full JSON-ification of model selection + train.
|
||||||
|
- Consider a "multi-train" alternative (or arg support in `train()`) for
|
||||||
|
training many "rollouts" from the same base estimator (basically forks under
|
||||||
|
different seeds). For architecture benchmarking above all, seeing average
|
||||||
|
training behavior. Consider corresponding `Plotter` methods (error bars)
|
||||||
29
example/example.json
Normal file
29
example/example.json
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
{
|
||||||
|
"estimator_name": "mlp",
|
||||||
|
"dataset_name": "random_xy_dataset",
|
||||||
|
"dataloader_name": "supervised_data_loader",
|
||||||
|
"estimator_kwargs": {
|
||||||
|
"input_dim": 4,
|
||||||
|
"output_dim": 2
|
||||||
|
},
|
||||||
|
"dataset_kwargs": {
|
||||||
|
"num_samples": 100000,
|
||||||
|
"preload": true,
|
||||||
|
"input_dim": 4,
|
||||||
|
"output_dim": 2
|
||||||
|
},
|
||||||
|
"dataset_split_fracs": {
|
||||||
|
"train": 0.4,
|
||||||
|
"val": 0.3,
|
||||||
|
"aux": [0.3]
|
||||||
|
},
|
||||||
|
"dataloader_kwargs": {
|
||||||
|
"batch_size": 16
|
||||||
|
},
|
||||||
|
"train_kwargs": {
|
||||||
|
"summarize_every": 20,
|
||||||
|
"max_epochs": 100,
|
||||||
|
"stop_after_epochs": 100
|
||||||
|
},
|
||||||
|
"load_only": false
|
||||||
|
}
|
||||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "trainlib"
|
name = "trainlib"
|
||||||
version = "0.1.2"
|
version = "0.3.1"
|
||||||
description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training."
|
description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training."
|
||||||
requires-python = ">=3.13"
|
requires-python = ">=3.13"
|
||||||
authors = [
|
authors = [
|
||||||
|
|||||||
18
trainlib/__main__.py
Normal file
18
trainlib/__main__.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from trainlib.cli import create_parser
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """
    Entry point for ``python -m trainlib``.

    Builds the CLI parser, applies the optional top-level ``--log-level``
    setting, then dispatches to the handler attached by the selected
    subcommand. If no subcommand was given, the top-level help is printed
    instead.
    """

    parser = create_parser()
    args = parser.parse_args()

    # skim off log level to handle higher-level option
    log_level = getattr(args, "log_level", None)
    if log_level is not None:
        logging.basicConfig(level=log_level)

    # subcommand parsers register their handler under ``func``; the attribute
    # is absent when the CLI is invoked without a subcommand
    handler = getattr(args, "func", None)
    if handler is None:
        parser.print_help()
        return

    handler(args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
26
trainlib/cli/__init__.py
Normal file
26
trainlib/cli/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
import logging
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from trainlib.cli import train
|
||||||
|
|
||||||
|
logger: logging.Logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_parser() -> argparse.ArgumentParser:
    """
    Assemble the top-level ``trainlib`` argument parser.

    Registers the global ``--log-level`` option and attaches the ``train``
    subcommand through ``train.register_parser``.

    Returns:
        The configured root parser.
    """

    root = argparse.ArgumentParser(description="trainlib cli")

    # numeric levels only; the help text spells out the name of each value
    level_choices = [10, 20, 30, 40, 50]
    root.add_argument(
        "--log-level",
        type=int,
        metavar="int",
        choices=level_choices,
        help="Log level: 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL",
    )

    commands = root.add_subparsers(help="subcommand help")
    train.register_parser(commands)

    return root
|
||||||
|
|
||||||
164
trainlib/cli/train.py
Normal file
164
trainlib/cli/train.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
import gc
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
from typing import Any
|
||||||
|
from argparse import _SubParsersAction
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from trainlib.trainer import Trainer
|
||||||
|
from trainlib.datasets import dataset_map
|
||||||
|
from trainlib.estimator import Estimator
|
||||||
|
from trainlib.estimators import estimator_map
|
||||||
|
from trainlib.dataloaders import dataloader_map
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_run() -> None:
    """
    Report current CUDA memory usage and free up memory ahead of a run.

    Prints the amount of memory currently allocated by tensors on the CUDA
    device (in GB), runs a garbage collection pass, and, when CUDA is
    available, releases cached allocator blocks.

    Garbage is collected *before* the cache is emptied: ``empty_cache`` can
    only release blocks that are no longer occupied, so tensors kept alive by
    unreachable reference cycles have to be collected first for their memory
    to be returned.
    """

    # prepare cuda memory
    memory_allocated = torch.cuda.memory_allocated() / 1024**3  # GB
    print(f"CUDA allocated: {memory_allocated}GB")

    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
def run(
|
||||||
|
estimator_name: str,
|
||||||
|
dataset_name: str,
|
||||||
|
dataloader_name: str,
|
||||||
|
estimator_kwargs: dict[str, Any] | None = None,
|
||||||
|
dataset_kwargs: dict[str, Any] | None = None,
|
||||||
|
dataset_split_fracs: dict[str, Any] | None = None,
|
||||||
|
dataset_split_kwargs: dict[str, Any] | None = None,
|
||||||
|
dataloader_kwargs: dict[str, Any] | None = None,
|
||||||
|
trainer_kwargs: dict[str, Any] | None = None,
|
||||||
|
train_kwargs: dict[str, Any] | None = None,
|
||||||
|
load_only: bool = False,
|
||||||
|
) -> Trainer | Estimator:
|
||||||
|
|
||||||
|
try:
|
||||||
|
estimator_cls = estimator_map[estimator_name]
|
||||||
|
except KeyError as err:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid estimator name '{estimator_name}',"
|
||||||
|
f"must be one of {estimator_map.keys()}"
|
||||||
|
) from err
|
||||||
|
|
||||||
|
try:
|
||||||
|
dataset_cls = dataset_map[dataset_name]
|
||||||
|
except KeyError as err:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid dataset name '{dataset_name}',"
|
||||||
|
f"must be one of {dataset_map.keys()}"
|
||||||
|
) from err
|
||||||
|
|
||||||
|
try:
|
||||||
|
dataloader_cls = dataloader_map[dataloader_name]
|
||||||
|
except KeyError as err:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid dataloader name '{dataloader_name}',"
|
||||||
|
f"must be one of {dataloader_map.keys()}"
|
||||||
|
) from err
|
||||||
|
|
||||||
|
estimator_kwargs = estimator_kwargs or {}
|
||||||
|
dataset_kwargs = dataset_kwargs or {}
|
||||||
|
dataset_split_fracs = dataset_split_fracs or {}
|
||||||
|
dataset_split_kwargs = dataset_split_kwargs or {}
|
||||||
|
dataloader_kwargs = dataloader_kwargs or {}
|
||||||
|
trainer_kwargs = trainer_kwargs or {}
|
||||||
|
train_kwargs = train_kwargs or {}
|
||||||
|
|
||||||
|
default_estimator_kwargs = {}
|
||||||
|
default_dataset_kwargs = {}
|
||||||
|
default_dataset_split_kwargs = {}
|
||||||
|
default_dataset_split_fracs = {"train": 1.0, "val": 0.0, "aux": []}
|
||||||
|
default_dataloader_kwargs = {}
|
||||||
|
default_trainer_kwargs = {}
|
||||||
|
default_train_kwargs = {}
|
||||||
|
|
||||||
|
estimator_kwargs = {**default_estimator_kwargs, **estimator_kwargs}
|
||||||
|
dataset_kwargs = {**default_dataset_kwargs, **dataset_kwargs}
|
||||||
|
dataset_split_kwargs = {**default_dataset_split_kwargs, **dataset_split_kwargs}
|
||||||
|
dataset_split_fracs = {**default_dataset_split_fracs, **dataset_split_fracs}
|
||||||
|
dataloader_kwargs = {**default_dataloader_kwargs, **dataloader_kwargs}
|
||||||
|
trainer_kwargs = {**default_trainer_kwargs, **trainer_kwargs}
|
||||||
|
train_kwargs = {**default_train_kwargs, **train_kwargs}
|
||||||
|
|
||||||
|
estimator = estimator_cls(**estimator_kwargs)
|
||||||
|
dataset = dataset_cls(**dataset_kwargs)
|
||||||
|
train_dataset, val_dataset, *aux_datasets = dataset.split(
|
||||||
|
fracs=[
|
||||||
|
dataset_split_fracs["train"],
|
||||||
|
dataset_split_fracs["val"],
|
||||||
|
*dataset_split_fracs["aux"]
|
||||||
|
],
|
||||||
|
**dataset_split_kwargs
|
||||||
|
)
|
||||||
|
train_loader = dataloader_cls(train_dataset, **dataloader_kwargs)
|
||||||
|
val_loader = dataloader_cls(val_dataset, **dataloader_kwargs)
|
||||||
|
aux_loaders = [
|
||||||
|
dataloader_cls(aux_dataset, **dataloader_kwargs)
|
||||||
|
for aux_dataset in aux_datasets
|
||||||
|
]
|
||||||
|
|
||||||
|
trainer = Trainer(
|
||||||
|
estimator,
|
||||||
|
**trainer_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if load_only:
|
||||||
|
return trainer
|
||||||
|
|
||||||
|
return trainer.train(
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
aux_loaders=aux_loaders,
|
||||||
|
**train_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def run_from_json(
    parameters_json: str | None = None,
    parameters_file: str | None = None,
) -> Trainer | Estimator:
    """
    Call ``run`` with keyword arguments loaded from JSON.

    The parameters come from a raw JSON string or from a JSON file. If both
    are given, the raw string takes precedence and the file is ignored. The
    decoded document must be a JSON object, since its members are passed to
    ``run`` as keyword arguments.

    Parameters:
        parameters_json: raw JSON string with the ``run`` arguments
        parameters_file: path to a UTF-8 JSON file with the ``run`` arguments

    Returns:
        Whatever ``run`` returns for the decoded parameters.

    Raises:
        ValueError: if neither source is given, the file does not exist, the
            JSON cannot be read or decoded, or the decoded document is not a
            JSON object.
    """

    parameters: Any
    if parameters_json:
        try:
            parameters = json.loads(parameters_json)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON format: {e}") from e
        except Exception as e:
            raise ValueError(f"Error loading JSON parameters: {e}") from e

    elif parameters_file:
        try:
            with open(parameters_file, encoding="utf-8") as f:
                parameters = json.load(f)
        except FileNotFoundError as e:
            raise ValueError(f"JSON file not found: {parameters_file}") from e
        except json.JSONDecodeError as e:
            # report malformed file contents the same way as inline JSON
            raise ValueError(f"Invalid JSON format: {e}") from e
        except Exception as e:
            raise ValueError(f"Error loading JSON parameters: {e}") from e

    else:
        raise ValueError("parameter json or file required")

    # ``run(**parameters)`` needs a mapping; fail here with a clear message
    # rather than with a TypeError from the call itself
    if not isinstance(parameters, dict):
        raise ValueError(
            "JSON parameters must be an object mapping argument names to "
            f"values, got {type(parameters).__name__}"
        )

    return run(**parameters)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_train(args: argparse.Namespace) -> None:
    """
    Handler for the ``train`` subcommand.

    Forwards the parsed ``--parameters-json`` / ``--parameters-file`` values
    to ``run_from_json``; the result of the run is discarded.
    """

    run_from_json(
        parameters_json=args.parameters_json,
        parameters_file=args.parameters_file,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def register_parser(subparsers: _SubParsersAction) -> None:
    """
    Attach the ``train`` subcommand to a subparsers action.

    The subcommand accepts its parameters either as a raw JSON string or as a
    path to a JSON file, and sets ``handle_train`` as the ``func`` default so
    the entry point can dispatch to it.
    """

    train_parser = subparsers.add_parser("train", help="run training loop")

    # both parameter sources are plain string options: (flag, help text)
    string_options = (
        ("--parameters-json", "Raw JSON string with train parameters"),
        ("--parameters-file", "Path to JSON file with train parameters"),
    )
    for flag, help_text in string_options:
        train_parser.add_argument(flag, type=str, help=help_text)

    train_parser.set_defaults(func=handle_train)
|
||||||
102
trainlib/dataloader.py
Normal file
102
trainlib/dataloader.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
"""
|
||||||
|
This class took me a long time to really settle into. It's a connector, and it
|
||||||
|
feels redundant in many ways, so I've nearly deleted it several times while
|
||||||
|
talking through the design. But in total, I think it serves a clear purpose.
|
||||||
|
Reasons:
|
||||||
|
|
||||||
|
- Need a typed dataloader, even if I know the type of my attached transform
|
||||||
|
- Need a new scope that uses the same base dataset without interfering with the
|
||||||
|
transform attribute; a design that sets or relies on that is subject to
|
||||||
|
conflict
|
||||||
|
|
||||||
|
- Why not just use vanilla DataLoaders?
|
||||||
|
|
||||||
|
I'd like to, but the two reasons above make it clear why this is challenging:
|
||||||
|
I don't get static checks on the structures returned during iteration, and
|
||||||
|
while you can control ad hoc data transforms via dataset ``post_transforms``,
|
||||||
|
things can get messy if you need to do that for many transforms using the
|
||||||
|
same dataset (without copying). Simplest way around this is just a new scope
|
||||||
|
with the same underlying dataset instance and a transform wrapper around the
|
||||||
|
iterator; no interference with object attributes.
|
||||||
|
|
||||||
|
This is really just meant as the minimum viable logic needed to accomplish the
|
||||||
|
above - it's a very lightweight wrapper on the base ``DataLoader`` object.
|
||||||
|
There's an explicit type upper bound ``Kw: EstimatorKwargs``, but it is
|
||||||
|
otherwise a completely general transform over dataloader batches, highlighting
|
||||||
|
that it's *mostly* here to play nice with type checks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Unpack
|
||||||
|
from collections.abc import Iterator
|
||||||
|
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from trainlib.dataset import BatchedDataset
|
||||||
|
from trainlib.estimator import EstimatorKwargs
|
||||||
|
from trainlib.utils.type import LoaderKwargs
|
||||||
|
|
||||||
|
|
||||||
|
class EstimatorDataLoader[B, Kw: EstimatorKwargs]:
|
||||||
|
"""
|
||||||
|
Data loaders for estimators.
|
||||||
|
|
||||||
|
This class exists to connect batched data from datasets to the expected
|
||||||
|
representation for estimator methods. Datasets may be developed
|
||||||
|
independently from a given model structures, and models should be trainable
|
||||||
|
under any such data. We need a way to ensure the batched groups of items we
|
||||||
|
get from dataloaders match on a type level, i.e., can be reshaped into the
|
||||||
|
expected ``Kw`` signature.
|
||||||
|
|
||||||
|
Note: batch structure ``B`` cannot be directly inferred from type variables
|
||||||
|
exposed by ``BatchedDatasets`` (namely ``R`` and ``I``). What's returned by
|
||||||
|
a data loader wrapping any such dataset can be arbitrary (depending on the
|
||||||
|
``collate_fn``), with default behavior being fairly consistent under nested
|
||||||
|
collections but challenging to accurately type.
|
||||||
|
|
||||||
|
.. todo::
|
||||||
|
|
||||||
|
To log (have changed for Trainer):
|
||||||
|
|
||||||
|
- New compact eval pipeline for train/val/auxiliary dataloaders.
|
||||||
|
Somewhat annoying logic, but handled consistently
|
||||||
|
- Convergence tracker will dynamically use training loss (early
|
||||||
|
stopping) when a validation set isn't provided. Same mechanics for
|
||||||
|
stagnant epochs (although early stopping is generally a little more
|
||||||
|
nuanced, having a rate-based stopper, b/c train loss generally quite
|
||||||
|
monotonic). So that's to be updated, plus room for possible model
|
||||||
|
selection strategies later.
|
||||||
|
- Logging happens at each batch, but we append to an epoch-indexed list
|
||||||
|
and later average. There was a bug in the last round of testing that
|
||||||
|
I didn't pick up where I was just overwriting summaries using the
|
||||||
|
last seen batch.
|
||||||
|
- Reworked general dataset/dataloader handling for main train loop, now
|
||||||
|
accepting objects of this class to bridge estimator and dataset
|
||||||
|
communication. This cleans up the batch mapping model.
|
||||||
|
- TODO: implement a version of this that canonically works with the
|
||||||
|
device passing plus EstimatorKwargs input; this is the last fuzzy bit
|
||||||
|
I think.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset: BatchedDataset,
|
||||||
|
**dataloader_kwargs: Unpack[LoaderKwargs],
|
||||||
|
) -> None:
|
||||||
|
self._dataloader = DataLoader(dataset, **dataloader_kwargs)
|
||||||
|
|
||||||
|
def batch_to_est_kwargs(self, batch_data: B) -> Kw:
|
||||||
|
"""
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
Even if we have a concrete shape for the output kwarg dict for base
|
||||||
|
estimators (requiring a tensor "inputs" attribute), we don't
|
||||||
|
presuppose how a given batch object will map into this dict
|
||||||
|
structure.
|
||||||
|
|
||||||
|
return EstimatorKwargs({"inputs":0})
|
||||||
|
"""
|
||||||
|
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[Kw]:
|
||||||
|
return map(self.batch_to_est_kwargs, self._dataloader)
|
||||||
12
trainlib/dataloaders/__init__.py
Normal file
12
trainlib/dataloaders/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from trainlib.dataloader import EstimatorDataLoader
|
||||||
|
from trainlib.utils.text import camel_to_snake
|
||||||
|
from trainlib.dataloaders.memory import SupervisedDataLoader
|
||||||
|
|
||||||
|
# Concrete dataloader classes that can be selected by name.
_dataloaders: list[type[EstimatorDataLoader]] = [
    SupervisedDataLoader,
]
# Name -> class registry, keyed by ``camel_to_snake`` of each class name.
dataloader_map: dict[str, type[EstimatorDataLoader]] = {
    camel_to_snake(_dataloader.__name__): _dataloader
    for _dataloader in _dataloaders
}
|
||||||
|
|
||||||
17
trainlib/dataloaders/memory.py
Normal file
17
trainlib/dataloaders/memory.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from trainlib.estimator import SupervisedKwargs
|
||||||
|
from trainlib.dataloader import EstimatorDataLoader
|
||||||
|
|
||||||
|
|
||||||
|
class SupervisedDataLoader(
    EstimatorDataLoader[tuple[Tensor, Tensor], SupervisedKwargs]
):
    """
    Dataloader for supervised estimators fed by ``(inputs, labels)`` batches.
    """

    def batch_to_est_kwargs(
        self,
        batch_data: tuple[Tensor, Tensor]
    ) -> SupervisedKwargs:
        """
        Map a two-element batch to ``SupervisedKwargs``.

        The first element becomes ``inputs`` and the second ``labels``.
        """

        inputs = batch_data[0]
        labels = batch_data[1]

        return SupervisedKwargs(inputs=inputs, labels=labels)
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
from trainlib.dataset import BatchedDataset
|
||||||
|
from trainlib.utils.text import camel_to_snake
|
||||||
|
from trainlib.datasets.memory import RandomXYDataset
|
||||||
|
|
||||||
|
# Concrete dataset classes that can be selected by name.
_datasets: list[type[BatchedDataset]] = [
    RandomXYDataset,
]
# Name -> class registry, keyed by ``camel_to_snake`` of each class name.
dataset_map: dict[str, type[BatchedDataset]] = {
    camel_to_snake(_dataset.__name__): _dataset
    for _dataset in _datasets
}
|
||||||
|
|
||||||
|
|||||||
@@ -73,13 +73,33 @@ class RecordDataset[T: NamedTuple](HomogenousDataset[int, T, T]):
|
|||||||
|
|
||||||
from typing import Unpack
|
from typing import Unpack
|
||||||
|
|
||||||
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
from torch.utils.data import TensorDataset
|
||||||
|
|
||||||
from trainlib.domain import SequenceDomain
|
from trainlib.domain import SequenceDomain
|
||||||
from trainlib.dataset import TupleDataset, DatasetKwargs
|
from trainlib.dataset import TupleDataset, DatasetKwargs
|
||||||
|
|
||||||
|
|
||||||
|
class RandomXYDataset(TupleDataset[Tensor]):
    """
    In-memory dataset of random ``(x, y)`` tensor pairs.

    Both tensors are drawn with ``torch.randn`` at construction time and held
    in a ``TensorDataset`` wrapped by a ``SequenceDomain``.
    """

    def __init__(
        self,
        num_samples: int,
        input_dim: int,
        output_dim: int,
        **kwargs: Unpack[DatasetKwargs],
    ) -> None:
        """
        Parameters:
            num_samples: number of ``(x, y)`` pairs to generate
            input_dim: size of each ``x`` vector
            output_dim: size of each ``y`` vector
            kwargs: forwarded to the ``TupleDataset`` constructor
        """

        # draw x before y so the order of RNG consumption is unchanged
        xs = torch.randn((num_samples, input_dim))
        ys = torch.randn((num_samples, output_dim))
        pairs = TensorDataset(xs, ys)

        domain = SequenceDomain[tuple[Tensor, Tensor]](pairs)

        super().__init__(domain, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class SlidingWindowDataset(TupleDataset[Tensor]):
|
class SlidingWindowDataset(TupleDataset[Tensor]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -88,12 +108,26 @@ class SlidingWindowDataset(TupleDataset[Tensor]):
|
|||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
lookahead: int = 1,
|
lookahead: int = 1,
|
||||||
num_windows: int = 1,
|
num_windows: int = 1,
|
||||||
|
pad_mode: str = "constant",
|
||||||
|
#fill_with: str = "zero",
|
||||||
**kwargs: Unpack[DatasetKwargs],
|
**kwargs: Unpack[DatasetKwargs],
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""
|
||||||
|
Parameters:
|
||||||
|
TODO: implement options for `fill_with`; currently just passing
|
||||||
|
through a `pad_mode` to the Functional call, which does the job
|
||||||
|
|
||||||
|
fill_with: strategy to use for padding values in windows
|
||||||
|
- `zero`: fill with zeros
|
||||||
|
- `left`: use nearest window column (repeat leftmost)
|
||||||
|
- `mean`: fill with the window mean
|
||||||
|
"""
|
||||||
|
|
||||||
self.lookback = lookback
|
self.lookback = lookback
|
||||||
self.offset = offset
|
self.offset = offset
|
||||||
self.lookahead = lookahead
|
self.lookahead = lookahead
|
||||||
self.num_windows = num_windows
|
self.num_windows = num_windows
|
||||||
|
self.pad_mode = pad_mode
|
||||||
|
|
||||||
super().__init__(domain, **kwargs)
|
super().__init__(domain, **kwargs)
|
||||||
|
|
||||||
@@ -103,7 +137,7 @@ class SlidingWindowDataset(TupleDataset[Tensor]):
|
|||||||
batch_index: int,
|
batch_index: int,
|
||||||
) -> list[tuple[Tensor, ...]]:
|
) -> list[tuple[Tensor, ...]]:
|
||||||
"""
|
"""
|
||||||
Backward pads first sequence over (lookback-1) length, and steps the
|
Backward pads window sequences over (lookback-1) length, and steps the
|
||||||
remaining items forward by the lookahead.
|
remaining items forward by the lookahead.
|
||||||
|
|
||||||
Batch data:
|
Batch data:
|
||||||
@@ -146,7 +180,7 @@ class SlidingWindowDataset(TupleDataset[Tensor]):
|
|||||||
exceeds the offset.
|
exceeds the offset.
|
||||||
|
|
||||||
To get windows starting with the first index at the left: we first set
|
To get windows starting with the first index at the left: we first set
|
||||||
out window size (call it L), determined by `lookback`. Then the
|
our window size (call it L), determined by `lookback`. Then the
|
||||||
rightmost index we want will be `L-1`, which determines our `offset`
|
rightmost index we want will be `L-1`, which determines our `offset`
|
||||||
setting.
|
setting.
|
||||||
|
|
||||||
@@ -173,7 +207,7 @@ class SlidingWindowDataset(TupleDataset[Tensor]):
|
|||||||
# for window sized `lb`, we pad with `lb-1` zeros. We then take off
|
# for window sized `lb`, we pad with `lb-1` zeros. We then take off
|
||||||
# the amount of our offset, which in the extreme cases does no
|
# the amount of our offset, which in the extreme cases does no
|
||||||
# padding.
|
# padding.
|
||||||
xip = F.pad(t, ((lb-1) - off, 0))
|
xip = F.pad(t, ((lb-1) - off, 0), mode=self.pad_mode)
|
||||||
|
|
||||||
# extract sliding windows over the padded tensor
|
# extract sliding windows over the padded tensor
|
||||||
# unfold(-1, lb, 1) slides over the last dim, 1 step at a time, for
|
# unfold(-1, lb, 1) slides over the last dim, 1 step at a time, for
|
||||||
|
|||||||
@@ -35,6 +35,10 @@ class EstimatorKwargs(TypedDict):
|
|||||||
inputs: Tensor
|
inputs: Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class SupervisedKwargs(EstimatorKwargs):
|
||||||
|
labels: Tensor
|
||||||
|
|
||||||
|
|
||||||
class Estimator[Kw: EstimatorKwargs](nn.Module):
|
class Estimator[Kw: EstimatorKwargs](nn.Module):
|
||||||
"""
|
"""
|
||||||
Estimator base class.
|
Estimator base class.
|
||||||
|
|||||||
@@ -0,0 +1,15 @@
|
|||||||
|
from trainlib.estimator import Estimator
|
||||||
|
from trainlib.utils.text import camel_to_snake
|
||||||
|
from trainlib.estimators.mlp import MLP
|
||||||
|
from trainlib.estimators.rnn import LSTM, ConvGRU, MultiheadLSTM
|
||||||
|
|
||||||
|
# Concrete estimator classes that can be selected by name.
_estimators: list[type[Estimator]] = [MLP, LSTM, MultiheadLSTM, ConvGRU]

# Name -> class registry, keyed by ``camel_to_snake`` of each class name.
estimator_map: dict[str, type[Estimator]] = {
    camel_to_snake(_cls.__name__): _cls for _cls in _estimators
}
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ class MLP[Kw: MLPKwargs](Estimator[Kw]):
|
|||||||
mae = F.l1_loss(predictions, labels).item()
|
mae = F.l1_loss(predictions, labels).item()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"mse": loss,
|
# "mse": loss,
|
||||||
"mae": mae,
|
"mae": mae,
|
||||||
"grad_norm": get_grad_norm(self)
|
"grad_norm": get_grad_norm(self)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from torch import nn, Tensor
|
|||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
from trainlib.utils import op
|
||||||
from trainlib.estimator import Estimator, EstimatorKwargs
|
from trainlib.estimator import Estimator, EstimatorKwargs
|
||||||
from trainlib.utils.type import OptimizerKwargs
|
from trainlib.utils.type import OptimizerKwargs
|
||||||
from trainlib.utils.module import get_grad_norm
|
from trainlib.utils.module import get_grad_norm
|
||||||
@@ -102,6 +103,7 @@ class LSTM[Kw: RNNKwargs](Estimator[Kw]):
|
|||||||
labels = kwargs["labels"]
|
labels = kwargs["labels"]
|
||||||
|
|
||||||
yield F.mse_loss(predictions, labels)
|
yield F.mse_loss(predictions, labels)
|
||||||
|
#yield F.l1_loss(predictions, labels)
|
||||||
|
|
||||||
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
|
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -109,12 +111,16 @@ class LSTM[Kw: RNNKwargs](Estimator[Kw]):
|
|||||||
|
|
||||||
predictions = self(**kwargs)[0]
|
predictions = self(**kwargs)[0]
|
||||||
labels = kwargs["labels"]
|
labels = kwargs["labels"]
|
||||||
|
|
||||||
|
mse = F.mse_loss(predictions, labels).item()
|
||||||
mae = F.l1_loss(predictions, labels).item()
|
mae = F.l1_loss(predictions, labels).item()
|
||||||
|
r2 = op.r2_score(predictions, labels).item()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
# "loss": loss,
|
"loss": loss,
|
||||||
"mse": loss,
|
"mse": mse,
|
||||||
"mae": mae,
|
"mae": mae,
|
||||||
|
"r2": r2,
|
||||||
"grad_norm": get_grad_norm(self)
|
"grad_norm": get_grad_norm(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -377,7 +383,8 @@ class ConvGRU[Kw: RNNKwargs](Estimator[Kw]):
|
|||||||
# will be (B, T, C), applies indep at each time step across channels
|
# will be (B, T, C), applies indep at each time step across channels
|
||||||
# self.dense_z = nn.Linear(layer_in_dim, self.output_dim)
|
# self.dense_z = nn.Linear(layer_in_dim, self.output_dim)
|
||||||
|
|
||||||
# will be (B, C, T), applies indep at each time step across channels
|
# will be (B, Co, 1), applies indep at each channel across temporal dim
|
||||||
|
# size time steps
|
||||||
self.dense_z = TDNNLayer(
|
self.dense_z = TDNNLayer(
|
||||||
layer_in_dim,
|
layer_in_dim,
|
||||||
self.output_dim,
|
self.output_dim,
|
||||||
@@ -438,6 +445,7 @@ class ConvGRU[Kw: RNNKwargs](Estimator[Kw]):
|
|||||||
predictions = predictions.squeeze(-1)
|
predictions = predictions.squeeze(-1)
|
||||||
|
|
||||||
yield F.mse_loss(predictions, labels, reduction="mean")
|
yield F.mse_loss(predictions, labels, reduction="mean")
|
||||||
|
#yield F.l1_loss(predictions, labels)
|
||||||
|
|
||||||
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
|
def metrics(self, **kwargs: Unpack[Kw]) -> dict[str, float]:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -445,11 +453,16 @@ class ConvGRU[Kw: RNNKwargs](Estimator[Kw]):
|
|||||||
|
|
||||||
predictions = self(**kwargs)[0].squeeze(-1)
|
predictions = self(**kwargs)[0].squeeze(-1)
|
||||||
labels = kwargs["labels"]
|
labels = kwargs["labels"]
|
||||||
|
|
||||||
|
mse = F.mse_loss(predictions, labels).item()
|
||||||
mae = F.l1_loss(predictions, labels).item()
|
mae = F.l1_loss(predictions, labels).item()
|
||||||
|
r2 = op.r2_score(predictions, labels).item()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"mse": loss,
|
"loss": loss,
|
||||||
|
"mse": mse,
|
||||||
"mae": mae,
|
"mae": mae,
|
||||||
|
"r2": r2,
|
||||||
"grad_norm": get_grad_norm(self)
|
"grad_norm": get_grad_norm(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,21 +1,21 @@
|
|||||||
from functools import partial
|
from typing import Any
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
from trainlib.trainer import Trainer
|
from trainlib.trainer import Trainer
|
||||||
from trainlib.estimator import EstimatorKwargs
|
from trainlib.estimator import EstimatorKwargs
|
||||||
|
from trainlib.dataloader import EstimatorDataLoader
|
||||||
from trainlib.utils.type import AxesArray, SubplotsKwargs
|
from trainlib.utils.type import AxesArray, SubplotsKwargs
|
||||||
|
|
||||||
type SubplotFn = Callable[[plt.Axes, int, Tensor, Tensor], None]
|
type SubplotFn = Callable[[plt.Axes, int, Tensor, Tensor], None]
|
||||||
type ContextFn = Callable[[plt.Axes, str], None]
|
type ContextFn = Callable[[plt.Axes, str], None]
|
||||||
|
|
||||||
|
|
||||||
class Plotter[B, K: EstimatorKwargs]:
|
class Plotter[Kw: EstimatorKwargs]:
|
||||||
"""
|
"""
|
||||||
TODOs:
|
TODOs:
|
||||||
|
|
||||||
@@ -26,14 +26,15 @@ class Plotter[B, K: EstimatorKwargs]:
|
|||||||
intervals broken over the training epochs at 0, 50, 100, 150, ... and
|
intervals broken over the training epochs at 0, 50, 100, 150, ... and
|
||||||
highlight the best one, even if that's not actually the single best
|
highlight the best one, even if that's not actually the single best
|
||||||
epoch)
|
epoch)
|
||||||
|
- Implement data and dimension limits; in the instance dataloaders have
|
||||||
|
huge numbers of samples or labels are high-dimensional
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
trainer: Trainer[..., K],
|
trainer: Trainer[Any, Kw],
|
||||||
dataloaders: list[DataLoader],
|
dataloaders: list[EstimatorDataLoader[Any, Kw]],
|
||||||
batch_estimator_map: Callable[[B, Trainer], ...],
|
kw_to_actual: Callable[[Kw], Tensor],
|
||||||
estimator_to_output_map: Callable[[K], ...],
|
|
||||||
dataloader_labels: list[str] | None = None,
|
dataloader_labels: list[str] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.trainer = trainer
|
self.trainer = trainer
|
||||||
@@ -41,47 +42,21 @@ class Plotter[B, K: EstimatorKwargs]:
|
|||||||
self.dataloader_labels = (
|
self.dataloader_labels = (
|
||||||
dataloader_labels or list(map(str, range(1, len(dataloaders)+1)))
|
dataloader_labels or list(map(str, range(1, len(dataloaders)+1)))
|
||||||
)
|
)
|
||||||
self.batch_estimator_map = batch_estimator_map
|
self.kw_to_actual = kw_to_actual
|
||||||
self.estimator_to_output_map = estimator_to_output_map
|
|
||||||
|
|
||||||
self._outputs: list[list[Tensor]] | None = None
|
self._outputs: list[list[Tensor]] | None = None
|
||||||
self._metrics: list[list[dict[str, float]]] | None = None
|
self._metrics: list[list[dict[str, float]]] | None = None
|
||||||
|
|
||||||
self._batch_outputs_fn = partial(
|
|
||||||
self.trainer.get_batch_outputs,
|
|
||||||
batch_estimator_map=batch_estimator_map
|
|
||||||
)
|
|
||||||
self._batch_metrics_fn = partial(
|
|
||||||
self.trainer.get_batch_metrics,
|
|
||||||
batch_estimator_map=batch_estimator_map
|
|
||||||
)
|
|
||||||
|
|
||||||
self._data_tuples = None
|
self._data_tuples = None
|
||||||
|
|
||||||
@property
|
|
||||||
def outputs(self) -> list[list[Tensor]]:
|
|
||||||
if self._outputs is None:
|
|
||||||
self._outputs = [
|
|
||||||
list(map(self._batch_outputs_fn, loader))
|
|
||||||
for loader in self.dataloaders
|
|
||||||
]
|
|
||||||
return self._outputs
|
|
||||||
|
|
||||||
@property
|
|
||||||
def metrics(self) -> list[list[dict[str, float]]]:
|
|
||||||
if self._metrics is None:
|
|
||||||
self._metrics = [
|
|
||||||
list(map(self._batch_metrics_fn, loader))
|
|
||||||
for loader in self.dataloaders
|
|
||||||
]
|
|
||||||
return self._metrics
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data_tuples(self) -> list[tuple[Tensor, Tensor, str]]:
|
def data_tuples(self) -> list[tuple[Tensor, Tensor, str]]:
|
||||||
"""
|
"""
|
||||||
Produce data items; to be cached. Zip later with axes
|
Produce data items; to be cached. Zip later with axes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
self.trainer.estimator.eval()
|
||||||
|
|
||||||
if self._data_tuples is not None:
|
if self._data_tuples is not None:
|
||||||
return self._data_tuples
|
return self._data_tuples
|
||||||
|
|
||||||
@@ -89,10 +64,17 @@ class Plotter[B, K: EstimatorKwargs]:
|
|||||||
for i, loader in enumerate(self.dataloaders):
|
for i, loader in enumerate(self.dataloaders):
|
||||||
label = self.dataloader_labels[i]
|
label = self.dataloader_labels[i]
|
||||||
|
|
||||||
batch = next(iter(loader))
|
actual = [
|
||||||
est_kwargs = self.batch_estimator_map(batch, self.trainer)
|
self.kw_to_actual(batch_kwargs).detach().cpu()
|
||||||
actual = self.estimator_to_output_map(est_kwargs)
|
for batch_kwargs in loader
|
||||||
output = self._batch_outputs_fn(batch)
|
]
|
||||||
|
actual = torch.cat([ai.reshape(*([*ai.shape]+[1])[:2]) for ai in actual])
|
||||||
|
|
||||||
|
output = [
|
||||||
|
self.trainer.estimator(**batch_kwargs)[0].detach().cpu()
|
||||||
|
for batch_kwargs in loader
|
||||||
|
]
|
||||||
|
output = torch.cat([oi.reshape(*([*oi.shape]+[1])[:2]) for oi in output])
|
||||||
|
|
||||||
data_tuples.append((actual, output, label))
|
data_tuples.append((actual, output, label))
|
||||||
|
|
||||||
@@ -219,6 +201,14 @@ class Plotter[B, K: EstimatorKwargs]:
|
|||||||
Note: transform samples in dataloader definitions beforehand if you
|
Note: transform samples in dataloader definitions beforehand if you
|
||||||
want to change data
|
want to change data
|
||||||
|
|
||||||
|
.. todo::
|
||||||
|
|
||||||
|
Merge in logic from general diagnostics, allowing collapse from
|
||||||
|
either dim and transposing.
|
||||||
|
|
||||||
|
Later: multi-trial error bars, or at least the ability to pass that
|
||||||
|
downstream
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
row_size:
|
row_size:
|
||||||
col_size:
|
col_size:
|
||||||
@@ -270,6 +260,12 @@ class Plotter[B, K: EstimatorKwargs]:
|
|||||||
|
|
||||||
return fig, axes
|
return fig, axes
|
||||||
|
|
||||||
|
# def plot_ordered(...): ...
|
||||||
|
# """
|
||||||
|
# Simple ordered view of output dimensions, with actual and output
|
||||||
|
# overlaid.
|
||||||
|
# """
|
||||||
|
|
||||||
def plot_actual_output(
|
def plot_actual_output(
|
||||||
self,
|
self,
|
||||||
row_size: int | float = 2,
|
row_size: int | float = 2,
|
||||||
@@ -462,3 +458,74 @@ class Plotter[B, K: EstimatorKwargs]:
|
|||||||
combine_dims=combine_dims,
|
combine_dims=combine_dims,
|
||||||
figure_kwargs=figure_kwargs,
|
figure_kwargs=figure_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def estimator_diagnostic(
|
||||||
|
self,
|
||||||
|
row_size: int | float = 2,
|
||||||
|
col_size: int | float = 4,
|
||||||
|
session_name: str | None = None,
|
||||||
|
combine_groups: bool = False,
|
||||||
|
combine_metrics: bool = False,
|
||||||
|
transpose_layout: bool = False,
|
||||||
|
figure_kwargs: SubplotsKwargs | None = None,
|
||||||
|
) -> tuple[plt.Figure, AxesArray]:
|
||||||
|
session_map = self.trainer._event_log
|
||||||
|
session_name = session_name or next(iter(session_map))
|
||||||
|
groups = session_map[session_name]
|
||||||
|
num_metrics = len(groups[next(iter(groups))])
|
||||||
|
# colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
|
||||||
|
|
||||||
|
rows = 1 if combine_groups else len(groups)
|
||||||
|
cols = 1 if combine_metrics else num_metrics
|
||||||
|
if transpose_layout:
|
||||||
|
rows, cols = cols, rows
|
||||||
|
|
||||||
|
fig, axes = self._create_subplots(
|
||||||
|
rows=rows,
|
||||||
|
cols=cols,
|
||||||
|
row_size=row_size,
|
||||||
|
col_size=col_size,
|
||||||
|
figure_kwargs=figure_kwargs,
|
||||||
|
)
|
||||||
|
if transpose_layout:
|
||||||
|
axes = axes.T
|
||||||
|
|
||||||
|
for i, group_name in enumerate(groups):
|
||||||
|
axes_row = axes[0 if combine_groups else i]
|
||||||
|
group_metrics = groups[group_name]
|
||||||
|
|
||||||
|
for j, metric_name in enumerate(group_metrics):
|
||||||
|
ax = axes_row[0 if combine_metrics else j]
|
||||||
|
|
||||||
|
metric_dict = group_metrics[metric_name]
|
||||||
|
metric_data = np.array([
|
||||||
|
(k, np.mean(v)) for k, v in metric_dict.items()
|
||||||
|
])
|
||||||
|
|
||||||
|
if combine_groups and combine_metrics:
|
||||||
|
label = f"{group_name}-{metric_name}"
|
||||||
|
title_prefix = "all"
|
||||||
|
elif combine_groups:
|
||||||
|
label = group_name
|
||||||
|
title_prefix = metric_name
|
||||||
|
# elif combine_metrics:
|
||||||
|
else:
|
||||||
|
label = metric_name
|
||||||
|
title_prefix = group_name
|
||||||
|
# else:
|
||||||
|
# label = ""
|
||||||
|
# title_prefix = f"{group_name},{metric_name}"
|
||||||
|
|
||||||
|
ax.plot(
|
||||||
|
metric_data[:, 0],
|
||||||
|
metric_data[:, 1],
|
||||||
|
label=label,
|
||||||
|
# color=colors[j],
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set_title(f"[{title_prefix}] Metrics over epochs")
|
||||||
|
ax.set_xlabel("epoch")
|
||||||
|
ax.set_ylabel("value")
|
||||||
|
ax.legend()
|
||||||
|
|
||||||
|
return fig, axes
|
||||||
|
|||||||
@@ -1,5 +1,15 @@
|
|||||||
"""
|
"""
|
||||||
Core interface for training ``Estimators`` with ``Datasets``
|
Core interface for training ``Estimators`` with ``Datasets``
|
||||||
|
|
||||||
|
.. admonition:: Design of preview ``get_dataloaders()``
|
||||||
|
|
||||||
|
|
||||||
|
Note how much this method is doing, and the benefit of letting that be
|
||||||
|
more explicit elsewhere. The assignment of transforms to datasets before
|
||||||
|
wrapping as loaders is chief among these items, alongside the balancing and
|
||||||
|
splitting; I think those are hamfisted here to make it work with the old
|
||||||
|
setup, but generally it's not consistent with the name "get dataloaders"
|
||||||
|
(i.e., and also balance and split and set transforms)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@@ -7,48 +17,42 @@ import time
|
|||||||
import logging
|
import logging
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, Self
|
from typing import Any
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Callable
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from torch import cuda, Tensor
|
from torch import cuda, Tensor
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
from torch.utils.data import Dataset, DataLoader
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from trainlib.dataset import BatchedDataset
|
|
||||||
from trainlib.estimator import Estimator, EstimatorKwargs
|
from trainlib.estimator import Estimator, EstimatorKwargs
|
||||||
from trainlib.transform import Transform
|
from trainlib.utils.map import nested_defaultdict
|
||||||
from trainlib.utils.type import (
|
from trainlib.dataloader import EstimatorDataLoader
|
||||||
SplitKwargs,
|
|
||||||
LoaderKwargs,
|
|
||||||
BalanceKwargs,
|
|
||||||
)
|
|
||||||
from trainlib.utils.module import ModelWrapper
|
from trainlib.utils.module import ModelWrapper
|
||||||
|
from trainlib.utils.session import ensure_same_device
|
||||||
|
|
||||||
logger: logging.Logger = logging.getLogger(__name__)
|
logger: logging.Logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Trainer[I, K: EstimatorKwargs]:
|
class Trainer[I, Kw: EstimatorKwargs]:
|
||||||
"""
|
"""
|
||||||
Training interface for optimizing parameters of ``Estimators`` with
|
Training interface for optimizing parameters of ``Estimators`` with
|
||||||
``Datasets``.
|
``Datasets``.
|
||||||
|
|
||||||
This class is generic to a dataset item type ``I`` and an estimator kwarg
|
This class is generic to a dataset item type ``I`` and an estimator kwarg
|
||||||
type ``K``. These are the two primary components ``Trainer`` objects need
|
type ``Kw``. These are the two primary components ``Trainer`` objects need
|
||||||
to coordinate: they ultimately rely on a provided map to ensure data items
|
to coordinate: they ultimately rely on a provided map to ensure data items
|
||||||
(type ``I``) from a dataset are appropriately routed as inputs to key
|
(type ``I``) from a dataset are appropriately routed as inputs to key
|
||||||
estimator methods (like ``forward()`` and ``loss()``), which accept inputs
|
estimator methods (like ``forward()`` and ``loss()``), which accept inputs
|
||||||
of type ``K``.
|
of type ``Kw``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
estimator: Estimator[K],
|
estimator: Estimator[Kw],
|
||||||
device: str | None = None,
|
device: str | None = None,
|
||||||
chkpt_dir: str = "chkpt/",
|
chkpt_dir: str = "chkpt/",
|
||||||
tblog_dir: str = "tblog/",
|
tblog_dir: str = "tblog/",
|
||||||
@@ -93,91 +97,87 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
|
|
||||||
self.estimator = estimator
|
self.estimator = estimator
|
||||||
self.estimator.to(self.device)
|
self.estimator.to(self.device)
|
||||||
|
self._event_log = nested_defaultdict(4, list)
|
||||||
|
|
||||||
self.chkpt_dir = Path(chkpt_dir).resolve()
|
self.chkpt_dir = Path(chkpt_dir).resolve()
|
||||||
self.tblog_dir = Path(tblog_dir).resolve()
|
self.tblog_dir = Path(tblog_dir).resolve()
|
||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self, resume: bool = False) -> None:
|
||||||
"""
|
"""
|
||||||
Set initial tracking parameters for the primary training loop.
|
Set initial tracking parameters for the primary training loop.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
resume: if ``True``, just resets the stagnant epoch counter, with
|
||||||
|
the aim of continuing any existing training state under a
|
||||||
|
resumed ``train()`` call. This should likely only be set when
|
||||||
|
training is continued on the same dataset and the goal is to
|
||||||
|
resume convergence loss-based scoring for a fresh set of
|
||||||
|
epochs. If even that element of the training loop should resume
|
||||||
|
(which should only happen if a training loop was interrupted or
|
||||||
|
a max epoch limit was reached), then this method shouldn't be
|
||||||
|
called at all between ``train()`` invocations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._epoch: int = 1
|
|
||||||
self._summary = defaultdict(lambda: defaultdict(dict))
|
|
||||||
self._event_log = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
|
|
||||||
|
|
||||||
self._val_loss = float("inf")
|
|
||||||
self._best_val_loss = float("inf")
|
|
||||||
self._stagnant_epochs = 0
|
self._stagnant_epochs = 0
|
||||||
self._best_model_state_dict: dict[str, Any] = {}
|
|
||||||
|
if not resume:
|
||||||
|
self._epoch: int = 0
|
||||||
|
self._summary = defaultdict(lambda: defaultdict(list))
|
||||||
|
|
||||||
|
self._conv_loss = float("inf")
|
||||||
|
self._best_conv_loss = float("inf")
|
||||||
|
self._best_conv_epoch = 0
|
||||||
|
self._best_model_state_dict: dict[str, Any] = {}
|
||||||
|
|
||||||
def _train_epoch(
|
def _train_epoch(
|
||||||
self,
|
self,
|
||||||
train_loader: DataLoader,
|
loader: EstimatorDataLoader[Any, Kw],
|
||||||
batch_estimator_map: Callable[[I, Self], K],
|
|
||||||
optimizers: tuple[Optimizer, ...],
|
optimizers: tuple[Optimizer, ...],
|
||||||
max_grad_norm: float | None = None,
|
max_grad_norm: float | None = None,
|
||||||
|
progress_bar: tqdm | None = None,
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
"""
|
"""
|
||||||
Train the estimator for a single epoch.
|
Train the estimator for a single epoch.
|
||||||
|
|
||||||
.. admonition:: On summary writers
|
|
||||||
|
|
||||||
Estimators can have several optimizers, and therefore can emit
|
|
||||||
several losses. This is a fairly unique case, but it's needed when
|
|
||||||
we want to optimize particular parameters in a particular order
|
|
||||||
(as in multi-model architectures, e.g., GANs). Point being: we
|
|
||||||
always iterate over optimizers/losses, even in the common case
|
|
||||||
where there's just a single value, and we index collections across
|
|
||||||
batches accordingly.
|
|
||||||
|
|
||||||
A few of the trackers, with the same size as the number of
|
|
||||||
optimizers:
|
|
||||||
|
|
||||||
- ``train_loss_sums``: tracks loss sums across all batches for the
|
|
||||||
epoch, used to update the loop preview text after each batch with
|
|
||||||
the current average loss
|
|
||||||
- ``train_loss_items``: collects current batch losses, recorded by
|
|
||||||
the TB writer
|
|
||||||
|
|
||||||
If there are ``M`` optimizers/losses, we log ``M`` loss terms to
|
|
||||||
the TB writer after each *batch* (not epoch). We could aggregate at
|
|
||||||
an epoch level, but parameter updates take place after each batch,
|
|
||||||
so large model changes can occur over the course of an epoch
|
|
||||||
(whereas the model remains the same over the course of batch evals).
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
loss_sums = []
|
loss_sums = []
|
||||||
self.estimator.train()
|
self.estimator.train()
|
||||||
with tqdm(train_loader, unit="batch") as batches:
|
for i, batch_kwargs in enumerate(loader):
|
||||||
for i, batch_data in enumerate(batches):
|
batch_kwargs = ensure_same_device(batch_kwargs, self.device)
|
||||||
est_kwargs = batch_estimator_map(batch_data, self)
|
losses = self.estimator.loss(**batch_kwargs)
|
||||||
losses = self.estimator.loss(**est_kwargs)
|
|
||||||
|
|
||||||
for o_idx, (loss, optimizer) in enumerate(
|
for o_idx, (loss, optimizer) in enumerate(
|
||||||
zip(losses, optimizers, strict=True)
|
zip(losses, optimizers, strict=True)
|
||||||
):
|
):
|
||||||
if len(loss_sums) <= o_idx:
|
if len(loss_sums) <= o_idx:
|
||||||
loss_sums.append(0.0)
|
loss_sums.append(0.0)
|
||||||
loss_sums[o_idx] += loss.item()
|
loss_sums[o_idx] += loss.item()
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
# clip gradients for optimizer's parameters
|
# clip gradients for optimizer's parameters
|
||||||
if max_grad_norm is not None:
|
if max_grad_norm is not None:
|
||||||
clip_grad_norm_(
|
clip_grad_norm_(
|
||||||
self._get_optimizer_parameters(optimizer),
|
self._get_optimizer_parameters(optimizer),
|
||||||
max_norm=max_grad_norm
|
max_norm=max_grad_norm
|
||||||
)
|
)
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
# set loop loss to running average (reducing if multi-loss)
|
# set loop loss to running average (reducing if multi-loss)
|
||||||
loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
|
loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
|
||||||
batches.set_postfix(loss=f"{loss_avg:8.2f}")
|
|
||||||
|
if progress_bar:
|
||||||
|
progress_bar.update(1)
|
||||||
|
progress_bar.set_postfix(
|
||||||
|
epoch=self._epoch,
|
||||||
|
mode="opt",
|
||||||
|
data="train",
|
||||||
|
loss=f"{loss_avg:8.2f}",
|
||||||
|
)
|
||||||
|
|
||||||
# step estimator hyperparam schedules
|
# step estimator hyperparam schedules
|
||||||
self.estimator.epoch_step()
|
self.estimator.epoch_step()
|
||||||
@@ -186,9 +186,9 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
|
|
||||||
def _eval_epoch(
|
def _eval_epoch(
|
||||||
self,
|
self,
|
||||||
loader: DataLoader,
|
loader: EstimatorDataLoader[Any, Kw],
|
||||||
batch_estimator_map: Callable[[I, Self], K],
|
label: str,
|
||||||
loader_label: str,
|
progress_bar: tqdm | None = None,
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
"""
|
"""
|
||||||
Perform and record validation scores for a single epoch.
|
Perform and record validation scores for a single epoch.
|
||||||
@@ -216,55 +216,65 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
|
|
||||||
loss_sums = []
|
loss_sums = []
|
||||||
self.estimator.eval()
|
self.estimator.eval()
|
||||||
with tqdm(loader, unit="batch") as batches:
|
for i, batch_kwargs in enumerate(loader):
|
||||||
for i, batch_data in enumerate(batches):
|
batch_kwargs = ensure_same_device(batch_kwargs, self.device)
|
||||||
est_kwargs = batch_estimator_map(batch_data, self)
|
losses = self.estimator.loss(**batch_kwargs)
|
||||||
losses = self.estimator.loss(**est_kwargs)
|
|
||||||
|
|
||||||
# one-time logging
|
|
||||||
if self._epoch == 0:
|
|
||||||
self._writer.add_graph(
|
|
||||||
ModelWrapper(self.estimator), est_kwargs
|
|
||||||
)
|
|
||||||
# once-per-epoch logging
|
|
||||||
if i == 0:
|
|
||||||
self.estimator.epoch_write(
|
|
||||||
self._writer,
|
|
||||||
step=self._epoch,
|
|
||||||
group=loader_label,
|
|
||||||
**est_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
loss_items = []
|
# once-per-session logging
|
||||||
for o_idx, loss in enumerate(losses):
|
if self._epoch == 0 and i == 0:
|
||||||
if len(loss_sums) <= o_idx:
|
self._writer.add_graph(
|
||||||
loss_sums.append(0.0)
|
ModelWrapper(self.estimator), batch_kwargs
|
||||||
|
)
|
||||||
|
# once-per-epoch logging
|
||||||
|
if i == 0:
|
||||||
|
self.estimator.epoch_write(
|
||||||
|
self._writer,
|
||||||
|
step=self._epoch,
|
||||||
|
group=label,
|
||||||
|
**batch_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
loss_item = loss.item()
|
loss_items = []
|
||||||
loss_sums[o_idx] += loss_item
|
for o_idx, loss in enumerate(losses):
|
||||||
loss_items.append(loss_item)
|
if len(loss_sums) <= o_idx:
|
||||||
|
loss_sums.append(0.0)
|
||||||
|
|
||||||
# set loop loss to running average (reducing if multi-loss)
|
loss_item = loss.item()
|
||||||
loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
|
loss_sums[o_idx] += loss_item
|
||||||
batches.set_postfix(loss=f"{loss_avg:8.2f}")
|
loss_items.append(loss_item)
|
||||||
|
|
||||||
# log individual loss terms after each batch
|
# set loop loss to running average (reducing if multi-loss)
|
||||||
for o_idx, loss_item in enumerate(loss_items):
|
loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
|
||||||
self._log_event(loader_label, f"loss_{o_idx}", loss_item)
|
|
||||||
|
|
||||||
# log metrics for batch
|
if progress_bar:
|
||||||
estimator_metrics = self.estimator.metrics(**est_kwargs)
|
progress_bar.update(1)
|
||||||
for metric_name, metric_value in estimator_metrics.items():
|
progress_bar.set_postfix(
|
||||||
self._log_event(loader_label, metric_name, metric_value)
|
epoch=self._epoch,
|
||||||
|
mode="eval",
|
||||||
|
data=label,
|
||||||
|
loss=f"{loss_avg:8.2f}",
|
||||||
|
)
|
||||||
|
|
||||||
return loss_sums
|
# log individual loss terms after each batch
|
||||||
|
for o_idx, loss_item in enumerate(loss_items):
|
||||||
|
self._log_event(label, f"loss_{o_idx}", loss_item)
|
||||||
|
|
||||||
|
# log metrics for batch
|
||||||
|
estimator_metrics = self.estimator.metrics(**batch_kwargs)
|
||||||
|
for metric_name, metric_value in estimator_metrics.items():
|
||||||
|
self._log_event(label, metric_name, metric_value)
|
||||||
|
|
||||||
|
avg_losses = [loss_sum / (i+1) for loss_sum in loss_sums]
|
||||||
|
|
||||||
|
return avg_losses
|
||||||
|
|
||||||
def _eval_loaders(
|
def _eval_loaders(
|
||||||
self,
|
self,
|
||||||
loaders: list[DataLoader],
|
train_loader: EstimatorDataLoader[Any, Kw],
|
||||||
batch_estimator_map: Callable[[I, Self], K],
|
val_loader: EstimatorDataLoader[Any, Kw] | None = None,
|
||||||
loader_labels: list[str],
|
aux_loaders: list[EstimatorDataLoader[Any, Kw]] | None = None,
|
||||||
) -> dict[str, list[float]]:
|
progress_bar: tqdm | None = None,
|
||||||
|
) -> tuple[list[float], list[float] | None, *list[float]]:
|
||||||
"""
|
"""
|
||||||
Evaluate estimator over each provided dataloader.
|
Evaluate estimator over each provided dataloader.
|
||||||
|
|
||||||
@@ -278,42 +288,73 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
``get_batch_outputs()`` or ``get_batch_metrics()`` while iterating over
|
``get_batch_outputs()`` or ``get_batch_metrics()`` while iterating over
|
||||||
batches. This will have no internal side effects and provides much more
|
batches. This will have no internal side effects and provides much more
|
||||||
information (just aggregated losses are provided here).
|
information (just aggregated losses are provided here).
|
||||||
|
|
||||||
|
.. admonition:: On epoch counting
|
||||||
|
|
||||||
|
Epoch counts start at 0 to allow for a sensible place to benchmark
|
||||||
|
the initial (potentially untrained/pre-trained) model before any
|
||||||
|
training data is seen. In the train loop, we increment the epoch
|
||||||
|
immediately, and all logging happens under the epoch value that's
|
||||||
|
set at the start of the iteration (rather than incrementing at the
|
||||||
|
end). Before beginning an additional training iteration, the
|
||||||
|
convergence condition in the ``while`` is effectively checking what
|
||||||
|
happened during the last epoch (the counter has not yet been
|
||||||
|
incremented); if no convergence, we begin again. (This is only
|
||||||
|
being noted because the epoch counting was previously quite
|
||||||
|
different: indexing started at ``1``, we incremented at the end of
|
||||||
|
the loop, and we didn't evaluate the model before the loop began.
|
||||||
|
This affects how we interpret plots and TensorBoard records, for
|
||||||
|
instance, so it's useful to spell out the approach clearly
|
||||||
|
somewhere given the many possible design choices here.)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return {
|
train_loss = self._eval_epoch(train_loader, "train", progress_bar)
|
||||||
label: self._eval_epoch(loader, batch_estimator_map, label)
|
|
||||||
for loader, label in zip(loaders, loader_labels, strict=True)
|
|
||||||
}
|
|
||||||
|
|
||||||
def train[B](
|
val_loss = None
|
||||||
|
if val_loader is not None:
|
||||||
|
val_loss = self._eval_epoch(val_loader, "val", progress_bar)
|
||||||
|
|
||||||
|
aux_loaders = aux_loaders or []
|
||||||
|
aux_losses = [
|
||||||
|
self._eval_epoch(aux_loader, f"aux{i}", progress_bar)
|
||||||
|
for i, aux_loader in enumerate(aux_loaders)
|
||||||
|
]
|
||||||
|
|
||||||
|
return train_loss, val_loss, *aux_losses
|
||||||
|
|
||||||
|
def train(
|
||||||
self,
|
self,
|
||||||
dataset: BatchedDataset[..., ..., I],
|
train_loader: EstimatorDataLoader[Any, Kw],
|
||||||
batch_estimator_map: Callable[[B, Self], K],
|
val_loader: EstimatorDataLoader[Any, Kw] | None = None,
|
||||||
|
aux_loaders: list[EstimatorDataLoader[Any, Kw]] | None = None,
|
||||||
|
*,
|
||||||
lr: float = 1e-3,
|
lr: float = 1e-3,
|
||||||
eps: float = 1e-8,
|
eps: float = 1e-8,
|
||||||
max_grad_norm: float | None = None,
|
max_grad_norm: float | None = None,
|
||||||
max_epochs: int = 10,
|
max_epochs: int = 10,
|
||||||
stop_after_epochs: int = 5,
|
stop_after_epochs: int = 5,
|
||||||
batch_size: int = 256,
|
|
||||||
val_frac: float = 0.1,
|
|
||||||
train_transform: Transform | None = None,
|
|
||||||
val_transform: Transform | None = None,
|
|
||||||
dataset_split_kwargs: SplitKwargs | None = None,
|
|
||||||
dataset_balance_kwargs: BalanceKwargs | None = None,
|
|
||||||
dataloader_kwargs: LoaderKwargs | None = None,
|
|
||||||
summarize_every: int = 1,
|
summarize_every: int = 1,
|
||||||
chkpt_every: int = 1,
|
chkpt_every: int = 1,
|
||||||
resume_latest: bool = False,
|
|
||||||
session_name: str | None = None,
|
session_name: str | None = None,
|
||||||
summary_writer: SummaryWriter | None = None,
|
summary_writer: SummaryWriter | None = None,
|
||||||
aux_loaders: list[DataLoader] | None = None,
|
|
||||||
aux_loader_labels: list[str] | None = None,
|
|
||||||
) -> Estimator:
|
) -> Estimator:
|
||||||
"""
|
"""
|
||||||
TODO: consider making the dataloader ``collate_fn`` an explicit
|
.. todo::
|
||||||
parameter with a type signature that reflects ``B``, connecting the
|
|
||||||
``batch_estimator_map`` somewhere. Might also re-type a ``DataLoader``
|
- consider making the dataloader ``collate_fn`` an explicit
|
||||||
in-house to allow a generic around ``B``
|
parameter with a type signature that reflects ``B``, connecting
|
||||||
|
the ``batch_estimator_map`` somewhere. Might also re-type a
|
||||||
|
``DataLoader`` in-house to allow a generic around ``B``
|
||||||
|
- Rework the validation specification. Accept something like a
|
||||||
|
"validate_with" parameter, or perhaps just move entirely to
|
||||||
|
accepting a dataloader list, label list. You might then also need
|
||||||
|
a "train_with," and you could set up sensible defaults so you
|
||||||
|
basically have the same interaction as now. The "problem" is you
|
||||||
|
always need a train set, and there's some clearly dependent logic
|
||||||
|
on a val set, but you don't *need* val, so this should be
|
||||||
|
slightly reworked (and the more general, *probably* the better in
|
||||||
|
this case, given I want to plug into the Plotter with possibly
|
||||||
|
several purely eval sets over the model training lifetime).
|
||||||
|
|
||||||
Note: this method attempts to implement a general scheme for passing
|
Note: this method attempts to implement a general scheme for passing
|
||||||
needed items to the estimator's loss function from the dataloader. The
|
needed items to the estimator's loss function from the dataloader. The
|
||||||
@@ -355,7 +396,7 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
This function should map from batches - which *may* be item
|
This function should map from batches - which *may* be item
|
||||||
shaped, i.e., have an ``I`` skeleton, even if stacked items may be
|
shaped, i.e., have an ``I`` skeleton, even if stacked items may be
|
||||||
different on the inside - into estimator keyword arguments (type
|
different on the inside - into estimator keyword arguments (type
|
||||||
``K``). Collation behavior from a DataLoader (which can be
|
``Kw``). Collation behavior from a DataLoader (which can be
|
||||||
customized) doesn't consistently yield a known type shape, however,
|
customized) doesn't consistently yield a known type shape, however,
|
||||||
so it's not appropriate to use ``I`` as the callable param type.
|
so it's not appropriate to use ``I`` as the callable param type.
|
||||||
|
|
||||||
@@ -416,85 +457,98 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
dataset
|
dataset
|
||||||
val_split_frac: fraction of dataset to use for validation
|
val_split_frac: fraction of dataset to use for validation
|
||||||
chkpt_every: how often model checkpoints should be saved
|
chkpt_every: how often model checkpoints should be saved
|
||||||
resume_latest: resume training from the latest available checkpoint
|
|
||||||
in the `chkpt_dir`
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logger.info("> Begin train loop:")
|
logger.info("> Begin train loop:")
|
||||||
logger.info(f"| > {lr=}")
|
logger.info(f"| > {lr=}")
|
||||||
logger.info(f"| > {eps=}")
|
logger.info(f"| > {eps=}")
|
||||||
logger.info(f"| > {max_epochs=}")
|
logger.info(f"| > {max_epochs=}")
|
||||||
logger.info(f"| > {batch_size=}")
|
|
||||||
logger.info(f"| > {val_frac=}")
|
|
||||||
logger.info(f"| > {chkpt_every=}")
|
logger.info(f"| > {chkpt_every=}")
|
||||||
logger.info(f"| > {resume_latest=}")
|
|
||||||
logger.info(f"| > with device: {self.device}")
|
logger.info(f"| > with device: {self.device}")
|
||||||
logger.info(f"| > core count: {os.cpu_count()}")
|
logger.info(f"| > core count: {os.cpu_count()}")
|
||||||
|
|
||||||
self._session_name = session_name or str(int(time.time()))
|
self._session_name = session_name or str(int(time.time()))
|
||||||
tblog_path = Path(self.tblog_dir, self._session_name)
|
tblog_path = Path(self.tblog_dir, self._session_name)
|
||||||
self._writer = summary_writer or SummaryWriter(f"{tblog_path}")
|
self._writer = summary_writer or SummaryWriter(f"{tblog_path}")
|
||||||
|
progress_bar = tqdm(train_loader, unit="batch")
|
||||||
aux_loaders = aux_loaders or []
|
|
||||||
aux_loader_labels = aux_loader_labels or []
|
|
||||||
|
|
||||||
optimizers = self.estimator.optimizers(lr=lr, eps=eps)
|
|
||||||
train_loader, val_loader = self.get_dataloaders(
|
|
||||||
dataset,
|
|
||||||
batch_size,
|
|
||||||
val_frac=val_frac,
|
|
||||||
train_transform=train_transform,
|
|
||||||
val_transform=val_transform,
|
|
||||||
dataset_split_kwargs=dataset_split_kwargs,
|
|
||||||
dataset_balance_kwargs=dataset_balance_kwargs,
|
|
||||||
dataloader_kwargs=dataloader_kwargs,
|
|
||||||
)
|
|
||||||
loaders = [train_loader, val_loader, *aux_loaders]
|
|
||||||
loader_labels = ["train", "val", *aux_loader_labels]
|
|
||||||
|
|
||||||
# evaluate model on dataloaders once before training starts
|
# evaluate model on dataloaders once before training starts
|
||||||
self._eval_loaders(loaders, batch_estimator_map, loader_labels)
|
train_loss, val_loss, *_ = self._eval_loaders(
|
||||||
|
train_loader, val_loader, aux_loaders, progress_bar
|
||||||
|
)
|
||||||
|
conv_loss = val_loss if val_loss else train_loss
|
||||||
|
self._conv_loss = sum(conv_loss) / len(conv_loss)
|
||||||
|
|
||||||
while self._epoch <= max_epochs and not self._converged(
|
optimizers = self.estimator.optimizers(lr=lr, eps=eps)
|
||||||
self._epoch, stop_after_epochs
|
|
||||||
):
|
while self._epoch < max_epochs and not self._converged(stop_after_epochs):
|
||||||
train_frac = f"{self._epoch}/{max_epochs}"
|
self._epoch += 1
|
||||||
stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}"
|
#train_frac = f"{self._epoch}/{max_epochs}"
|
||||||
print(f"Training epoch {train_frac}...")
|
#stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}"
|
||||||
print(f"Stagnant epochs {stag_frac}...")
|
#print(f"Training epoch {train_frac}...")
|
||||||
|
#print(f"Stagnant epochs {stag_frac}...")
|
||||||
|
|
||||||
epoch_start_time = time.time()
|
epoch_start_time = time.time()
|
||||||
self._train_epoch(
|
self._train_epoch(
|
||||||
train_loader,
|
train_loader,
|
||||||
batch_estimator_map,
|
|
||||||
optimizers,
|
optimizers,
|
||||||
max_grad_norm,
|
max_grad_norm,
|
||||||
|
progress_bar=progress_bar
|
||||||
)
|
)
|
||||||
epoch_end_time = time.time() - epoch_start_time
|
epoch_end_time = time.time() - epoch_start_time
|
||||||
self._log_event("train", "epoch_duration", epoch_end_time)
|
self._log_event("train", "epoch_duration", epoch_end_time)
|
||||||
|
|
||||||
loss_sum_map = self._eval_loaders(
|
train_loss, val_loss, *_ = self._eval_loaders(
|
||||||
loaders,
|
train_loader, val_loader, aux_loaders, progress_bar
|
||||||
batch_estimator_map,
|
|
||||||
loader_labels,
|
|
||||||
)
|
)
|
||||||
val_loss_sums = loss_sum_map["val"]
|
# determine loss to use for measuring convergence
|
||||||
self._val_loss = sum(val_loss_sums) / len(val_loader)
|
conv_loss = val_loss if val_loss else train_loss
|
||||||
|
self._conv_loss = sum(conv_loss) / len(conv_loss)
|
||||||
|
|
||||||
if self._epoch % summarize_every == 0:
|
if self._epoch % summarize_every == 0:
|
||||||
self._summarize()
|
self._summarize()
|
||||||
if self._epoch % chkpt_every == 0:
|
if self._epoch % chkpt_every == 0:
|
||||||
self.save_model()
|
self.save_model()
|
||||||
self._epoch += 1
|
|
||||||
|
|
||||||
return self.estimator
|
return self.estimator
|
||||||
|
|
||||||
def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
|
def _converged(self, stop_after_epochs: int) -> bool:
|
||||||
|
"""
|
||||||
|
Check if model has converged.
|
||||||
|
|
||||||
|
This method looks at the current "convergence loss" (validation-based
|
||||||
|
if a val set is provided to ``train()``, otherwise the training loss is
|
||||||
|
used), checking if it's the best yet recorded, incrementing the
|
||||||
|
stagnancy count if not. Convergence is asserted only if the number of
|
||||||
|
stagnant epochs exceeds ``stop_after_epochs``.
|
||||||
|
|
||||||
|
.. admonition:: Evaluation order
|
||||||
|
|
||||||
|
Convergence losses are recorded before the first training update,
|
||||||
|
so initial model states are appropriately benchmarked by the time
|
||||||
|
``_converged()`` is invoked.
|
||||||
|
|
||||||
|
If resuming training on the same dataset, one might expect only to
|
||||||
|
reset the stagnant epoch counter: you'll resume from the last
|
||||||
|
epoch, estimator state, and best seen loss, while allowed
|
||||||
|
``stop_after_epochs`` more chances for better validation.
|
||||||
|
|
||||||
|
If picking up training on a new dataset, even a training+validation
|
||||||
|
setting, resetting the best seen loss and best model state is
|
||||||
|
needed: you can't reliably compare the existing stats under new
|
||||||
|
data. It's somewhat ambiguous whether ``epoch`` absolutely must be
|
||||||
|
reset; you could continue logging metrics under the same named
|
||||||
|
session. But best practices would suggest restarting the epoch
|
||||||
|
count and have events logged under a new session heading when data
|
||||||
|
change.
|
||||||
|
"""
|
||||||
|
|
||||||
converged = False
|
converged = False
|
||||||
|
|
||||||
if epoch == 1 or self._val_loss < self._best_val_loss:
|
if self._conv_loss < self._best_conv_loss:
|
||||||
self._best_val_loss = self._val_loss
|
|
||||||
self._stagnant_epochs = 0
|
self._stagnant_epochs = 0
|
||||||
|
self._best_conv_loss = self._conv_loss
|
||||||
|
self._best_conv_epoch = self._epoch
|
||||||
self._best_model_state_dict = deepcopy(self.estimator.state_dict())
|
self._best_model_state_dict = deepcopy(self.estimator.state_dict())
|
||||||
else:
|
else:
|
||||||
self._stagnant_epochs += 1
|
self._stagnant_epochs += 1
|
||||||
@@ -505,110 +559,25 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
|
|
||||||
return converged
|
return converged
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_dataloaders(
|
|
||||||
dataset: BatchedDataset,
|
|
||||||
batch_size: int,
|
|
||||||
val_frac: float = 0.1,
|
|
||||||
train_transform: Transform | None = None,
|
|
||||||
val_transform: Transform | None = None,
|
|
||||||
dataset_split_kwargs: SplitKwargs | None = None,
|
|
||||||
dataset_balance_kwargs: BalanceKwargs | None = None,
|
|
||||||
dataloader_kwargs: LoaderKwargs | None = None,
|
|
||||||
) -> tuple[DataLoader, DataLoader]:
|
|
||||||
"""
|
|
||||||
Create training and validation dataloaders for the provided dataset.
|
|
||||||
|
|
||||||
.. todo::
|
|
||||||
|
|
||||||
Decide on policy for empty val dataloaders
|
|
||||||
"""
|
|
||||||
|
|
||||||
if dataset_split_kwargs is None:
|
|
||||||
dataset_split_kwargs = {}
|
|
||||||
|
|
||||||
if dataset_balance_kwargs is not None:
|
|
||||||
dataset.balance(**dataset_balance_kwargs)
|
|
||||||
|
|
||||||
if val_frac <= 0:
|
|
||||||
dataset.post_transform = train_transform
|
|
||||||
train_loader_kwargs: LoaderKwargs = {
|
|
||||||
"batch_size": min(batch_size, len(dataset)),
|
|
||||||
"num_workers": 0,
|
|
||||||
"shuffle": True,
|
|
||||||
}
|
|
||||||
if dataloader_kwargs is not None:
|
|
||||||
train_loader_kwargs: LoaderKwargs = {
|
|
||||||
**train_loader_kwargs,
|
|
||||||
**dataloader_kwargs
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
DataLoader(dataset, **train_loader_kwargs),
|
|
||||||
DataLoader(Dataset())
|
|
||||||
)
|
|
||||||
|
|
||||||
train_dataset, val_dataset = dataset.split(
|
|
||||||
[1 - val_frac, val_frac],
|
|
||||||
**dataset_split_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Dataset.split() returns light Subset objects of shallow copies of the
|
|
||||||
# underlying dataset; can change the transform attribute of both splits
|
|
||||||
# w/o overwriting
|
|
||||||
train_dataset.post_transform = train_transform
|
|
||||||
val_dataset.post_transform = val_transform
|
|
||||||
|
|
||||||
train_loader_kwargs: LoaderKwargs = {
|
|
||||||
"batch_size": min(batch_size, len(train_dataset)),
|
|
||||||
"num_workers": 0,
|
|
||||||
"shuffle": True,
|
|
||||||
}
|
|
||||||
val_loader_kwargs: LoaderKwargs = {
|
|
||||||
"batch_size": min(batch_size, len(val_dataset)),
|
|
||||||
"num_workers": 0,
|
|
||||||
"shuffle": True, # shuffle to prevent homogeneous val batches
|
|
||||||
}
|
|
||||||
|
|
||||||
if dataloader_kwargs is not None:
|
|
||||||
train_loader_kwargs = {**train_loader_kwargs, **dataloader_kwargs}
|
|
||||||
val_loader_kwargs = {**val_loader_kwargs, **dataloader_kwargs}
|
|
||||||
|
|
||||||
train_loader = DataLoader(train_dataset, **train_loader_kwargs)
|
|
||||||
val_loader = DataLoader(val_dataset, **val_loader_kwargs)
|
|
||||||
|
|
||||||
return train_loader, val_loader
|
|
||||||
|
|
||||||
def _summarize(self) -> None:
    """
    Flush the training summary to the TensorBoard summary writer and print
    metrics for the current epoch.

    Every ``(group, name)`` entry of ``self._summary`` maps an epoch to the
    list of values logged during that epoch. Each list is averaged and
    written as a single scalar tagged ``"{group}-{name}"`` at that epoch;
    only entries recorded for the current epoch (``self._epoch``) are also
    echoed to stdout. The writer is flushed and the summary buffer is then
    replaced with a fresh, empty one.
    """

    print(f"==== Epoch [{self._epoch}] summary ====")

    # Force a floating dtype when building the tensor: integer-valued
    # events (counts, step totals) would otherwise produce an integer
    # tensor, and ``Tensor.mean`` refuses to average those.
    float_dtype = torch.get_default_dtype()

    for (group, name), epoch_map in self._summary.items():
        for epoch, values in epoch_map.items():
            # compute average over batch items recorded for the epoch
            mean = torch.tensor(values, dtype=float_dtype).mean().item()
            self._writer.add_scalar(f"{group}-{name}", mean, epoch)
            if epoch == self._epoch:
                print(
                    f"> ({len(values)}) [{group}] {name} :: {mean:.2f}"
                )

    self._writer.flush()
    self._summary = defaultdict(lambda: defaultdict(list))
|
||||||
|
|
||||||
def _get_optimizer_parameters(
|
def _get_optimizer_parameters(
|
||||||
self,
|
self,
|
||||||
@@ -622,33 +591,10 @@ class Trainer[I, K: EstimatorKwargs]:
|
|||||||
]
|
]
|
||||||
|
|
||||||
def _log_event(self, group: str, name: str, value: float) -> None:
|
def _log_event(self, group: str, name: str, value: float) -> None:
|
||||||
self._summary[group][name][self._epoch] = value
|
session, epoch = self._session_name, self._epoch
|
||||||
self._event_log[self._session_name][group][name][self._epoch] = value
|
|
||||||
|
|
||||||
def get_batch_outputs[B](
|
self._summary[group, name][epoch].append(value)
|
||||||
self,
|
self._event_log[session][group][name][epoch].append(value)
|
||||||
batch: B,
|
|
||||||
batch_estimator_map: Callable[[B, Self], K],
|
|
||||||
) -> Tensor:
|
|
||||||
self.estimator.eval()
|
|
||||||
|
|
||||||
est_kwargs = batch_estimator_map(batch, self)
|
|
||||||
output = self.estimator(**est_kwargs)[0]
|
|
||||||
output = output.detach().cpu()
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
def get_batch_metrics[B](
|
|
||||||
self,
|
|
||||||
batch: B,
|
|
||||||
batch_estimator_map: Callable[[B, Self], K],
|
|
||||||
) -> dict[str, float]:
|
|
||||||
self.estimator.eval()
|
|
||||||
|
|
||||||
est_kwargs = batch_estimator_map(batch, self)
|
|
||||||
metrics = self.estimator.metrics(**est_kwargs)
|
|
||||||
|
|
||||||
return metrics
|
|
||||||
|
|
||||||
def save_model(self) -> None:
|
def save_model(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
10
trainlib/utils/map.py
Normal file
10
trainlib/utils/map.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
|
||||||
|
def nested_defaultdict(
    depth: int,
    final: type = dict,
) -> defaultdict:
    """
    Build a ``defaultdict`` that auto-creates ``depth`` levels of keys.

    The first ``depth - 1`` levels are themselves ``defaultdict`` objects
    created on first access; the innermost level produces ``final()`` for
    missing keys, e.g. ``nested_defaultdict(3, list)["a"]["b"]["c"]`` is a
    new empty list.

    Parameters:
        depth: number of key levels; must be at least 1.
        final: zero-argument factory for the values stored at the innermost
            level.

    Returns:
        The outermost ``defaultdict``.

    Raises:
        ValueError: if ``depth`` is smaller than 1. Without this check the
            recursion would never reach ``final`` and the mapping would
            nest without end.
    """

    # local import keeps this block self-contained within the module
    from functools import partial

    if depth < 1:
        raise ValueError(f"depth must be >= 1, got {depth}")

    if depth == 1:
        return defaultdict(final)

    # ``partial`` rather than a lambda: lambdas cannot be pickled, which
    # would make the whole structure unpicklable (e.g. when handed to a
    # worker process)
    return defaultdict(partial(nested_defaultdict, depth - 1, final))
|
||||||
9
trainlib/utils/op.py
Normal file
9
trainlib/utils/op.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
def r2_score(y: Tensor, y_hat: Tensor) -> Tensor:
    """
    Compute the coefficient of determination (R^2) of predictions.

    The score is ``1 - ss_res / ss_tot``, where ``ss_res`` is the sum of
    squared residuals ``(y - y_hat)**2`` and ``ss_tot`` is the sum of
    squared deviations of ``y`` from its mean. Both sums run over every
    element, so the result is a single scalar tensor.

    When ``y`` is constant, ``ss_tot`` is zero and the ratio is undefined
    (NaN or -inf). In that case the score is reported as ``1`` if the
    predictions are exact and ``0`` otherwise, so downstream averages and
    logs stay finite.

    Parameters:
        y: ground-truth values (floating point).
        y_hat: predicted values. NOTE(review): ``y - y_hat`` broadcasts, so
            the two are presumably expected to share a shape - confirm
            against callers.

    Returns:
        Zero-dimensional tensor holding the R^2 score.
    """

    ss_res = ((y - y_hat) ** 2).sum()
    ss_tot = ((y - y.mean()) ** 2).sum()

    if ss_tot == 0:
        # degenerate target: avoid dividing by zero
        return (ss_res == 0).to(ss_res.dtype)

    return 1 - ss_res / ss_tot
|
||||||
@@ -3,6 +3,7 @@ import random
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
from torch.utils import _pytree as pytree
|
||||||
|
|
||||||
|
|
||||||
def seed_all_backends(seed: int | Tensor | None = None) -> None:
|
def seed_all_backends(seed: int | Tensor | None = None) -> None:
|
||||||
@@ -19,3 +20,9 @@ def seed_all_backends(seed: int | Tensor | None = None) -> None:
|
|||||||
torch.cuda.manual_seed(seed)
|
torch.cuda.manual_seed(seed)
|
||||||
torch.backends.cudnn.deterministic = True
|
torch.backends.cudnn.deterministic = True
|
||||||
torch.backends.cudnn.benchmark = False
|
torch.backends.cudnn.benchmark = False
|
||||||
|
|
||||||
|
def ensure_same_device[T](tree: T, device: str) -> T:
|
||||||
|
return pytree.tree_map(
|
||||||
|
lambda x: x.to(device) if isinstance(x, torch.Tensor) else x,
|
||||||
|
tree,
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,8 +1,14 @@
|
|||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from colorama import Style
|
from colorama import Style
|
||||||
|
|
||||||
|
# Zero-width boundaries where an underscore belongs: before a capital that
# starts a capitalised word (except at the very start of the string), or
# between a lowercase letter and the capital that follows it.
camel2snake_regex: re.Pattern[str] = re.compile(
    r"(?<!^)(?=[A-Z][a-z])|(?<=[a-z])(?=[A-Z])"
)


def camel_to_snake(text: str) -> str:
    """
    Convert a CamelCase identifier to snake_case.

    An underscore is inserted at every boundary matched by
    ``camel2snake_regex`` and the result is lowercased, so runs of
    capitals stay together: ``"HTTPServer"`` becomes ``"http_server"``.
    """

    with_separators = camel2snake_regex.sub("_", text)
    return with_separators.lower()
|
||||||
|
|
||||||
def color_text(text: str, *colorama_args: Any) -> str:
    """
    Wrap ``text`` in the given colorama codes.

    All ``colorama_args`` are concatenated in the order given and placed
    before ``text``; ``Style.RESET_ALL`` is appended after it.
    """

    prefix = "".join(colorama_args)
    return f"{prefix}{text}{Style.RESET_ALL}"
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from trainlib.dataset import BatchedDataset
|
|||||||
type AxesArray = np.ndarray[tuple[int, int], np.dtype[np.object_]]
|
type AxesArray = np.ndarray[tuple[int, int], np.dtype[np.object_]]
|
||||||
|
|
||||||
class LoaderKwargs(TypedDict, total=False):
|
class LoaderKwargs(TypedDict, total=False):
|
||||||
batch_size: int
|
batch_size: int | None
|
||||||
shuffle: bool
|
shuffle: bool
|
||||||
sampler: Sampler | Iterable | None
|
sampler: Sampler | Iterable | None
|
||||||
batch_sampler: Sampler[list] | Iterable[list] | None
|
batch_sampler: Sampler[list] | Iterable[list] | None
|
||||||
@@ -61,8 +61,8 @@ class SubplotsKwargs(TypedDict, total=False):
|
|||||||
squeeze: bool
|
squeeze: bool
|
||||||
width_ratios: Sequence[float]
|
width_ratios: Sequence[float]
|
||||||
height_ratios: Sequence[float]
|
height_ratios: Sequence[float]
|
||||||
subplot_kw: dict[str, ...]
|
subplot_kw: dict[str, Any]
|
||||||
gridspec_kw: dict[str, ...]
|
gridspec_kw: dict[str, Any]
|
||||||
figsize: tuple[float, float]
|
figsize: tuple[float, float]
|
||||||
dpi: float
|
dpi: float
|
||||||
layout: str
|
layout: str
|
||||||
|
|||||||
Reference in New Issue
Block a user