Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ba0c804d5e |
29
example/example.json
Normal file
29
example/example.json
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
{
|
||||||
|
"estimator_name": "mlp",
|
||||||
|
"dataset_name": "random_xy_dataset",
|
||||||
|
"dataloader_name": "supervised_data_loader",
|
||||||
|
"estimator_kwargs": {
|
||||||
|
"input_dim": 4,
|
||||||
|
"output_dim": 2
|
||||||
|
},
|
||||||
|
"dataset_kwargs": {
|
||||||
|
"num_samples": 100000,
|
||||||
|
"preload": true,
|
||||||
|
"input_dim": 4,
|
||||||
|
"output_dim": 2
|
||||||
|
},
|
||||||
|
"dataset_split_fracs": {
|
||||||
|
"train": 0.4,
|
||||||
|
"val": 0.3,
|
||||||
|
"aux": [0.3]
|
||||||
|
},
|
||||||
|
"dataloader_kwargs": {
|
||||||
|
"batch_size": 16
|
||||||
|
},
|
||||||
|
"train_kwargs": {
|
||||||
|
"summarize_every": 20,
|
||||||
|
"max_epochs": 100,
|
||||||
|
"stop_after_epochs": 100
|
||||||
|
},
|
||||||
|
"load_only": false
|
||||||
|
}
|
||||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "trainlib"
|
name = "trainlib"
|
||||||
version = "0.2.0"
|
version = "0.3.0"
|
||||||
description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training."
|
description = "Minimal framework for ML modeling. Supports advanced dataset operations and streamlined training."
|
||||||
requires-python = ">=3.13"
|
requires-python = ">=3.13"
|
||||||
authors = [
|
authors = [
|
||||||
|
|||||||
18
trainlib/__main__.py
Normal file
18
trainlib/__main__.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from trainlib.cli import create_parser
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = create_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# skim off log level to handle higher-level option
|
||||||
|
if hasattr(args, "log_level") and args.log_level is not None:
|
||||||
|
logging.basicConfig(level=args.log_level)
|
||||||
|
|
||||||
|
args.func(args) if "func" in args else parser.print_help()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
26
trainlib/cli/__init__.py
Normal file
26
trainlib/cli/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
import logging
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from trainlib.cli import train
|
||||||
|
|
||||||
|
logger: logging.Logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_parser() -> argparse.ArgumentParser:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="trainlib cli",
|
||||||
|
# formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--log-level",
|
||||||
|
type=int,
|
||||||
|
metavar="int",
|
||||||
|
choices=[10, 20, 30, 40, 50],
|
||||||
|
help="Log level: 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL",
|
||||||
|
)
|
||||||
|
|
||||||
|
subparsers = parser.add_subparsers(help="subcommand help")
|
||||||
|
train.register_parser(subparsers)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
164
trainlib/cli/train.py
Normal file
164
trainlib/cli/train.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
import gc
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
from typing import Any
|
||||||
|
from argparse import _SubParsersAction
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from trainlib.trainer import Trainer
|
||||||
|
from trainlib.datasets import dataset_map
|
||||||
|
from trainlib.estimator import Estimator
|
||||||
|
from trainlib.estimators import estimator_map
|
||||||
|
from trainlib.dataloaders import dataloader_map
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_run() -> None:
|
||||||
|
# prepare cuda memory
|
||||||
|
memory_allocated = torch.cuda.memory_allocated() / 1024**3 # GB
|
||||||
|
print(f"CUDA allocated: {memory_allocated}GB")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
|
def run(
|
||||||
|
estimator_name: str,
|
||||||
|
dataset_name: str,
|
||||||
|
dataloader_name: str,
|
||||||
|
estimator_kwargs: dict[str, Any] | None = None,
|
||||||
|
dataset_kwargs: dict[str, Any] | None = None,
|
||||||
|
dataset_split_fracs: dict[str, Any] | None = None,
|
||||||
|
dataset_split_kwargs: dict[str, Any] | None = None,
|
||||||
|
dataloader_kwargs: dict[str, Any] | None = None,
|
||||||
|
trainer_kwargs: dict[str, Any] | None = None,
|
||||||
|
train_kwargs: dict[str, Any] | None = None,
|
||||||
|
load_only: bool = False,
|
||||||
|
) -> Trainer | Estimator:
|
||||||
|
|
||||||
|
try:
|
||||||
|
estimator_cls = estimator_map[estimator_name]
|
||||||
|
except KeyError as err:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid estimator name '{estimator_name}',"
|
||||||
|
f"must be one of {estimator_map.keys()}"
|
||||||
|
) from err
|
||||||
|
|
||||||
|
try:
|
||||||
|
dataset_cls = dataset_map[dataset_name]
|
||||||
|
except KeyError as err:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid dataset name '{dataset_name}',"
|
||||||
|
f"must be one of {dataset_map.keys()}"
|
||||||
|
) from err
|
||||||
|
|
||||||
|
try:
|
||||||
|
dataloader_cls = dataloader_map[dataloader_name]
|
||||||
|
except KeyError as err:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid dataloader name '{dataloader_name}',"
|
||||||
|
f"must be one of {dataloader_map.keys()}"
|
||||||
|
) from err
|
||||||
|
|
||||||
|
estimator_kwargs = estimator_kwargs or {}
|
||||||
|
dataset_kwargs = dataset_kwargs or {}
|
||||||
|
dataset_split_fracs = dataset_split_fracs or {}
|
||||||
|
dataset_split_kwargs = dataset_split_kwargs or {}
|
||||||
|
dataloader_kwargs = dataloader_kwargs or {}
|
||||||
|
trainer_kwargs = trainer_kwargs or {}
|
||||||
|
train_kwargs = train_kwargs or {}
|
||||||
|
|
||||||
|
default_estimator_kwargs = {}
|
||||||
|
default_dataset_kwargs = {}
|
||||||
|
default_dataset_split_kwargs = {}
|
||||||
|
default_dataset_split_fracs = {"train": 1.0, "val": 0.0, "aux": []}
|
||||||
|
default_dataloader_kwargs = {}
|
||||||
|
default_trainer_kwargs = {}
|
||||||
|
default_train_kwargs = {}
|
||||||
|
|
||||||
|
estimator_kwargs = {**default_estimator_kwargs, **estimator_kwargs}
|
||||||
|
dataset_kwargs = {**default_dataset_kwargs, **dataset_kwargs}
|
||||||
|
dataset_split_kwargs = {**default_dataset_split_kwargs, **dataset_split_kwargs}
|
||||||
|
dataset_split_fracs = {**default_dataset_split_fracs, **dataset_split_fracs}
|
||||||
|
dataloader_kwargs = {**default_dataloader_kwargs, **dataloader_kwargs}
|
||||||
|
trainer_kwargs = {**default_trainer_kwargs, **trainer_kwargs}
|
||||||
|
train_kwargs = {**default_train_kwargs, **train_kwargs}
|
||||||
|
|
||||||
|
estimator = estimator_cls(**estimator_kwargs)
|
||||||
|
dataset = dataset_cls(**dataset_kwargs)
|
||||||
|
train_dataset, val_dataset, *aux_datasets = dataset.split(
|
||||||
|
fracs=[
|
||||||
|
dataset_split_fracs["train"],
|
||||||
|
dataset_split_fracs["val"],
|
||||||
|
*dataset_split_fracs["aux"]
|
||||||
|
],
|
||||||
|
**dataset_split_kwargs
|
||||||
|
)
|
||||||
|
train_loader = dataloader_cls(train_dataset, **dataloader_kwargs)
|
||||||
|
val_loader = dataloader_cls(val_dataset, **dataloader_kwargs)
|
||||||
|
aux_loaders = [
|
||||||
|
dataloader_cls(aux_dataset, **dataloader_kwargs)
|
||||||
|
for aux_dataset in aux_datasets
|
||||||
|
]
|
||||||
|
|
||||||
|
trainer = Trainer(
|
||||||
|
estimator,
|
||||||
|
**trainer_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if load_only:
|
||||||
|
return trainer
|
||||||
|
|
||||||
|
return trainer.train(
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
aux_loaders=aux_loaders,
|
||||||
|
**train_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def run_from_json(
|
||||||
|
parameters_json: str | None = None,
|
||||||
|
parameters_file: str | None = None,
|
||||||
|
) -> Trainer | Estimator:
|
||||||
|
if not (parameters_json or parameters_file):
|
||||||
|
raise ValueError("parameter json or file required")
|
||||||
|
|
||||||
|
parameters: dict[str, Any]
|
||||||
|
if parameters_json:
|
||||||
|
try:
|
||||||
|
parameters = json.loads(parameters_json)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise ValueError(f"Invalid JSON format: {e}") from e
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Error loading JSON parameters: {e}") from e
|
||||||
|
|
||||||
|
elif parameters_file:
|
||||||
|
try:
|
||||||
|
with open(parameters_file, encoding="utf-8") as f:
|
||||||
|
parameters = json.load(f)
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
raise ValueError(f"JSON file not found: {parameters_file}") from e
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Error loading JSON parameters: {e}") from e
|
||||||
|
|
||||||
|
return run(**parameters)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_train(args: argparse.Namespace) -> None:
|
||||||
|
run_from_json(args.parameters_json, args.parameters_file)
|
||||||
|
|
||||||
|
|
||||||
|
def register_parser(subparsers: _SubParsersAction) -> None:
|
||||||
|
parser = subparsers.add_parser("train", help="run training loop")
|
||||||
|
parser.add_argument(
|
||||||
|
"--parameters-json",
|
||||||
|
type=str,
|
||||||
|
help="Raw JSON string with train parameters",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--parameters-file",
|
||||||
|
type=str,
|
||||||
|
help="Path to JSON file with train parameters",
|
||||||
|
)
|
||||||
|
parser.set_defaults(func=handle_train)
|
||||||
12
trainlib/dataloaders/__init__.py
Normal file
12
trainlib/dataloaders/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from trainlib.dataloader import EstimatorDataLoader
|
||||||
|
from trainlib.utils.text import camel_to_snake
|
||||||
|
from trainlib.dataloaders.memory import SupervisedDataLoader
|
||||||
|
|
||||||
|
_dataloaders = [
|
||||||
|
SupervisedDataLoader,
|
||||||
|
]
|
||||||
|
dataloader_map: dict[str, type[EstimatorDataLoader]] = {
|
||||||
|
camel_to_snake(_dataloader.__name__): _dataloader
|
||||||
|
for _dataloader in _dataloaders
|
||||||
|
}
|
||||||
|
|
||||||
17
trainlib/dataloaders/memory.py
Normal file
17
trainlib/dataloaders/memory.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from trainlib.estimator import SupervisedKwargs
|
||||||
|
from trainlib.dataloader import EstimatorDataLoader
|
||||||
|
|
||||||
|
|
||||||
|
class SupervisedDataLoader(
|
||||||
|
EstimatorDataLoader[tuple[Tensor, Tensor], SupervisedKwargs]
|
||||||
|
):
|
||||||
|
def batch_to_est_kwargs(
|
||||||
|
self,
|
||||||
|
batch_data: tuple[Tensor, Tensor]
|
||||||
|
) -> SupervisedKwargs:
|
||||||
|
return SupervisedKwargs(
|
||||||
|
inputs=batch_data[0],
|
||||||
|
labels=batch_data[1],
|
||||||
|
)
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
from trainlib.dataset import BatchedDataset
|
||||||
|
from trainlib.utils.text import camel_to_snake
|
||||||
|
from trainlib.datasets.memory import RandomXYDataset
|
||||||
|
|
||||||
|
_datasets = [
|
||||||
|
RandomXYDataset,
|
||||||
|
]
|
||||||
|
dataset_map: dict[str, type[BatchedDataset]] = {
|
||||||
|
camel_to_snake(_dataset.__name__): _dataset
|
||||||
|
for _dataset in _datasets
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -73,13 +73,33 @@ class RecordDataset[T: NamedTuple](HomogenousDataset[int, T, T]):
|
|||||||
|
|
||||||
from typing import Unpack
|
from typing import Unpack
|
||||||
|
|
||||||
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
from torch.utils.data import TensorDataset
|
||||||
|
|
||||||
from trainlib.domain import SequenceDomain
|
from trainlib.domain import SequenceDomain
|
||||||
from trainlib.dataset import TupleDataset, DatasetKwargs
|
from trainlib.dataset import TupleDataset, DatasetKwargs
|
||||||
|
|
||||||
|
|
||||||
|
class RandomXYDataset(TupleDataset[Tensor]):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_samples: int,
|
||||||
|
input_dim: int,
|
||||||
|
output_dim: int,
|
||||||
|
**kwargs: Unpack[DatasetKwargs],
|
||||||
|
) -> None:
|
||||||
|
domain = SequenceDomain[tuple[Tensor, Tensor]](
|
||||||
|
TensorDataset(
|
||||||
|
torch.randn((num_samples, input_dim)),
|
||||||
|
torch.randn((num_samples, output_dim))
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(domain, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class SlidingWindowDataset(TupleDataset[Tensor]):
|
class SlidingWindowDataset(TupleDataset[Tensor]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -35,6 +35,10 @@ class EstimatorKwargs(TypedDict):
|
|||||||
inputs: Tensor
|
inputs: Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class SupervisedKwargs(EstimatorKwargs):
|
||||||
|
labels: Tensor
|
||||||
|
|
||||||
|
|
||||||
class Estimator[Kw: EstimatorKwargs](nn.Module):
|
class Estimator[Kw: EstimatorKwargs](nn.Module):
|
||||||
"""
|
"""
|
||||||
Estimator base class.
|
Estimator base class.
|
||||||
|
|||||||
@@ -0,0 +1,15 @@
|
|||||||
|
from trainlib.estimator import Estimator
|
||||||
|
from trainlib.utils.text import camel_to_snake
|
||||||
|
from trainlib.estimators.mlp import MLP
|
||||||
|
from trainlib.estimators.rnn import LSTM, ConvGRU, MultiheadLSTM
|
||||||
|
|
||||||
|
_estimators: list[type[Estimator]] = [
|
||||||
|
MLP,
|
||||||
|
LSTM,
|
||||||
|
MultiheadLSTM,
|
||||||
|
ConvGRU,
|
||||||
|
]
|
||||||
|
estimator_map: dict[str, type[Estimator]] = {
|
||||||
|
camel_to_snake(_estimator.__name__): _estimator
|
||||||
|
for _estimator in _estimators
|
||||||
|
}
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ from trainlib.estimator import Estimator, EstimatorKwargs
|
|||||||
from trainlib.utils.map import nested_defaultdict
|
from trainlib.utils.map import nested_defaultdict
|
||||||
from trainlib.dataloader import EstimatorDataLoader
|
from trainlib.dataloader import EstimatorDataLoader
|
||||||
from trainlib.utils.module import ModelWrapper
|
from trainlib.utils.module import ModelWrapper
|
||||||
|
from trainlib.utils.session import ensure_same_device
|
||||||
|
|
||||||
logger: logging.Logger = logging.getLogger(__name__)
|
logger: logging.Logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -121,6 +122,7 @@ class Trainer[I, Kw: EstimatorKwargs]:
|
|||||||
loader: EstimatorDataLoader[Any, Kw],
|
loader: EstimatorDataLoader[Any, Kw],
|
||||||
optimizers: tuple[Optimizer, ...],
|
optimizers: tuple[Optimizer, ...],
|
||||||
max_grad_norm: float | None = None,
|
max_grad_norm: float | None = None,
|
||||||
|
progress_bar: tqdm | None = None,
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
"""
|
"""
|
||||||
Train the estimator for a single epoch.
|
Train the estimator for a single epoch.
|
||||||
@@ -128,8 +130,8 @@ class Trainer[I, Kw: EstimatorKwargs]:
|
|||||||
|
|
||||||
loss_sums = []
|
loss_sums = []
|
||||||
self.estimator.train()
|
self.estimator.train()
|
||||||
with tqdm(loader, unit="batch") as batches:
|
for i, batch_kwargs in enumerate(loader):
|
||||||
for i, batch_kwargs in enumerate(batches):
|
batch_kwargs = ensure_same_device(batch_kwargs, self.device)
|
||||||
losses = self.estimator.loss(**batch_kwargs)
|
losses = self.estimator.loss(**batch_kwargs)
|
||||||
|
|
||||||
for o_idx, (loss, optimizer) in enumerate(
|
for o_idx, (loss, optimizer) in enumerate(
|
||||||
@@ -153,7 +155,15 @@ class Trainer[I, Kw: EstimatorKwargs]:
|
|||||||
|
|
||||||
# set loop loss to running average (reducing if multi-loss)
|
# set loop loss to running average (reducing if multi-loss)
|
||||||
loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
|
loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
|
||||||
batches.set_postfix(loss=f"{loss_avg:8.2f}")
|
|
||||||
|
if progress_bar:
|
||||||
|
progress_bar.update(1)
|
||||||
|
progress_bar.set_postfix(
|
||||||
|
epoch=self._epoch,
|
||||||
|
mode="opt",
|
||||||
|
data="train",
|
||||||
|
loss=f"{loss_avg:8.2f}",
|
||||||
|
)
|
||||||
|
|
||||||
# step estimator hyperparam schedules
|
# step estimator hyperparam schedules
|
||||||
self.estimator.epoch_step()
|
self.estimator.epoch_step()
|
||||||
@@ -163,7 +173,8 @@ class Trainer[I, Kw: EstimatorKwargs]:
|
|||||||
def _eval_epoch(
|
def _eval_epoch(
|
||||||
self,
|
self,
|
||||||
loader: EstimatorDataLoader[Any, Kw],
|
loader: EstimatorDataLoader[Any, Kw],
|
||||||
label: str
|
label: str,
|
||||||
|
progress_bar: tqdm | None = None,
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
"""
|
"""
|
||||||
Perform and record validation scores for a single epoch.
|
Perform and record validation scores for a single epoch.
|
||||||
@@ -191,12 +202,12 @@ class Trainer[I, Kw: EstimatorKwargs]:
|
|||||||
|
|
||||||
loss_sums = []
|
loss_sums = []
|
||||||
self.estimator.eval()
|
self.estimator.eval()
|
||||||
with tqdm(loader, unit="batch") as batches:
|
for i, batch_kwargs in enumerate(loader):
|
||||||
for i, batch_kwargs in enumerate(batches):
|
batch_kwargs = ensure_same_device(batch_kwargs, self.device)
|
||||||
losses = self.estimator.loss(**batch_kwargs)
|
losses = self.estimator.loss(**batch_kwargs)
|
||||||
|
|
||||||
# one-time logging
|
# once-per-session logging
|
||||||
if self._epoch == 0:
|
if self._epoch == 0 and i == 0:
|
||||||
self._writer.add_graph(
|
self._writer.add_graph(
|
||||||
ModelWrapper(self.estimator), batch_kwargs
|
ModelWrapper(self.estimator), batch_kwargs
|
||||||
)
|
)
|
||||||
@@ -220,7 +231,15 @@ class Trainer[I, Kw: EstimatorKwargs]:
|
|||||||
|
|
||||||
# set loop loss to running average (reducing if multi-loss)
|
# set loop loss to running average (reducing if multi-loss)
|
||||||
loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
|
loss_avg = sum(loss_sums) / (len(loss_sums)*(i+1))
|
||||||
batches.set_postfix(loss=f"{loss_avg:8.2f}")
|
|
||||||
|
if progress_bar:
|
||||||
|
progress_bar.update(1)
|
||||||
|
progress_bar.set_postfix(
|
||||||
|
epoch=self._epoch,
|
||||||
|
mode="eval",
|
||||||
|
data=label,
|
||||||
|
loss=f"{loss_avg:8.2f}",
|
||||||
|
)
|
||||||
|
|
||||||
# log individual loss terms after each batch
|
# log individual loss terms after each batch
|
||||||
for o_idx, loss_item in enumerate(loss_items):
|
for o_idx, loss_item in enumerate(loss_items):
|
||||||
@@ -240,6 +259,7 @@ class Trainer[I, Kw: EstimatorKwargs]:
|
|||||||
train_loader: EstimatorDataLoader[Any, Kw],
|
train_loader: EstimatorDataLoader[Any, Kw],
|
||||||
val_loader: EstimatorDataLoader[Any, Kw] | None = None,
|
val_loader: EstimatorDataLoader[Any, Kw] | None = None,
|
||||||
aux_loaders: list[EstimatorDataLoader[Any, Kw]] | None = None,
|
aux_loaders: list[EstimatorDataLoader[Any, Kw]] | None = None,
|
||||||
|
progress_bar: tqdm | None = None,
|
||||||
) -> tuple[list[float], list[float] | None, *list[float]]:
|
) -> tuple[list[float], list[float] | None, *list[float]]:
|
||||||
"""
|
"""
|
||||||
Evaluate estimator over each provided dataloader.
|
Evaluate estimator over each provided dataloader.
|
||||||
@@ -274,12 +294,15 @@ class Trainer[I, Kw: EstimatorKwargs]:
|
|||||||
somewhere given the many possible design choices here.)
|
somewhere given the many possible design choices here.)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
train_loss = self._eval_epoch(train_loader, "train")
|
train_loss = self._eval_epoch(train_loader, "train", progress_bar)
|
||||||
val_loss = self._eval_epoch(val_loader, "val") if val_loader else None
|
|
||||||
|
val_loss = None
|
||||||
|
if val_loader is not None:
|
||||||
|
val_loss = self._eval_epoch(val_loader, "val", progress_bar)
|
||||||
|
|
||||||
aux_loaders = aux_loaders or []
|
aux_loaders = aux_loaders or []
|
||||||
aux_losses = [
|
aux_losses = [
|
||||||
self._eval_epoch(aux_loader, f"aux{i}")
|
self._eval_epoch(aux_loader, f"aux{i}", progress_bar)
|
||||||
for i, aux_loader in enumerate(aux_loaders)
|
for i, aux_loader in enumerate(aux_loaders)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -433,9 +456,10 @@ class Trainer[I, Kw: EstimatorKwargs]:
|
|||||||
self._session_name = session_name or str(int(time.time()))
|
self._session_name = session_name or str(int(time.time()))
|
||||||
tblog_path = Path(self.tblog_dir, self._session_name)
|
tblog_path = Path(self.tblog_dir, self._session_name)
|
||||||
self._writer = summary_writer or SummaryWriter(f"{tblog_path}")
|
self._writer = summary_writer or SummaryWriter(f"{tblog_path}")
|
||||||
|
progress_bar = tqdm(train_loader, unit="batch")
|
||||||
|
|
||||||
# evaluate model on dataloaders once before training starts
|
# evaluate model on dataloaders once before training starts
|
||||||
self._eval_loaders(train_loader, val_loader, aux_loaders)
|
self._eval_loaders(train_loader, val_loader, aux_loaders, progress_bar)
|
||||||
optimizers = self.estimator.optimizers(lr=lr, eps=eps)
|
optimizers = self.estimator.optimizers(lr=lr, eps=eps)
|
||||||
|
|
||||||
while self._epoch < max_epochs and not self._converged(
|
while self._epoch < max_epochs and not self._converged(
|
||||||
@@ -444,16 +468,21 @@ class Trainer[I, Kw: EstimatorKwargs]:
|
|||||||
self._epoch += 1
|
self._epoch += 1
|
||||||
train_frac = f"{self._epoch}/{max_epochs}"
|
train_frac = f"{self._epoch}/{max_epochs}"
|
||||||
stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}"
|
stag_frac = f"{self._stagnant_epochs}/{stop_after_epochs}"
|
||||||
print(f"Training epoch {train_frac}...")
|
#print(f"Training epoch {train_frac}...")
|
||||||
print(f"Stagnant epochs {stag_frac}...")
|
#print(f"Stagnant epochs {stag_frac}...")
|
||||||
|
|
||||||
epoch_start_time = time.time()
|
epoch_start_time = time.time()
|
||||||
self._train_epoch(train_loader, optimizers, max_grad_norm)
|
self._train_epoch(
|
||||||
|
train_loader,
|
||||||
|
optimizers,
|
||||||
|
max_grad_norm,
|
||||||
|
progress_bar=progress_bar
|
||||||
|
)
|
||||||
epoch_end_time = time.time() - epoch_start_time
|
epoch_end_time = time.time() - epoch_start_time
|
||||||
self._log_event("train", "epoch_duration", epoch_end_time)
|
self._log_event("train", "epoch_duration", epoch_end_time)
|
||||||
|
|
||||||
train_loss, val_loss, _ = self._eval_loaders(
|
train_loss, val_loss, *_ = self._eval_loaders(
|
||||||
train_loader, val_loader, aux_loaders
|
train_loader, val_loader, aux_loaders, progress_bar
|
||||||
)
|
)
|
||||||
# determine loss to use for measuring convergence
|
# determine loss to use for measuring convergence
|
||||||
conv_loss = val_loss if val_loss else train_loss
|
conv_loss = val_loss if val_loss else train_loss
|
||||||
@@ -491,6 +520,7 @@ class Trainer[I, Kw: EstimatorKwargs]:
|
|||||||
print(f"==== Epoch [{self._epoch}] summary ====")
|
print(f"==== Epoch [{self._epoch}] summary ====")
|
||||||
for (group, name), epoch_map in self._summary.items():
|
for (group, name), epoch_map in self._summary.items():
|
||||||
for epoch, values in epoch_map.items():
|
for epoch, values in epoch_map.items():
|
||||||
|
# compute average over batch items recorded for the epoch
|
||||||
mean = torch.tensor(values).mean().item()
|
mean = torch.tensor(values).mean().item()
|
||||||
self._writer.add_scalar(f"{group}-{name}", mean, epoch)
|
self._writer.add_scalar(f"{group}-{name}", mean, epoch)
|
||||||
if epoch == self._epoch:
|
if epoch == self._epoch:
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import random
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
from torch.utils import _pytree as pytree
|
||||||
|
|
||||||
|
|
||||||
def seed_all_backends(seed: int | Tensor | None = None) -> None:
|
def seed_all_backends(seed: int | Tensor | None = None) -> None:
|
||||||
@@ -19,3 +20,9 @@ def seed_all_backends(seed: int | Tensor | None = None) -> None:
|
|||||||
torch.cuda.manual_seed(seed)
|
torch.cuda.manual_seed(seed)
|
||||||
torch.backends.cudnn.deterministic = True
|
torch.backends.cudnn.deterministic = True
|
||||||
torch.backends.cudnn.benchmark = False
|
torch.backends.cudnn.benchmark = False
|
||||||
|
|
||||||
|
def ensure_same_device[T](tree: T, device: str) -> T:
|
||||||
|
return pytree.tree_map(
|
||||||
|
lambda x: x.to(device) if isinstance(x, torch.Tensor) else x,
|
||||||
|
tree,
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,8 +1,14 @@
|
|||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from colorama import Style
|
from colorama import Style
|
||||||
|
|
||||||
|
camel2snake_regex: re.Pattern[str] = re.compile(
|
||||||
|
r"(?<!^)(?=[A-Z][a-z])|(?<=[a-z])(?=[A-Z])"
|
||||||
|
)
|
||||||
|
|
||||||
|
def camel_to_snake(text: str) -> str:
|
||||||
|
return camel2snake_regex.sub("_", text).lower()
|
||||||
|
|
||||||
def color_text(text: str, *colorama_args: Any) -> str:
|
def color_text(text: str, *colorama_args: Any) -> str:
|
||||||
return f"{''.join(colorama_args)}{text}{Style.RESET_ALL}"
|
return f"{''.join(colorama_args)}{text}{Style.RESET_ALL}"
|
||||||
|
|
||||||
|
|||||||
@@ -61,8 +61,8 @@ class SubplotsKwargs(TypedDict, total=False):
|
|||||||
squeeze: bool
|
squeeze: bool
|
||||||
width_ratios: Sequence[float]
|
width_ratios: Sequence[float]
|
||||||
height_ratios: Sequence[float]
|
height_ratios: Sequence[float]
|
||||||
subplot_kw: dict[str, ...]
|
subplot_kw: dict[str, Any]
|
||||||
gridspec_kw: dict[str, ...]
|
gridspec_kw: dict[str, Any]
|
||||||
figsize: tuple[float, float]
|
figsize: tuple[float, float]
|
||||||
dpi: float
|
dpi: float
|
||||||
layout: str
|
layout: str
|
||||||
|
|||||||
Reference in New Issue
Block a user