refactor Trainer train/val loop
This commit is contained in:
20
example/dataset.py
Normal file
20
example/dataset.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import NamedTuple
|
||||
|
||||
from trainlib.domain import SequenceDomain
|
||||
from trainlib.datasets.memory import TupleDataset
|
||||
|
||||
|
||||
class Record(NamedTuple):
|
||||
a: int
|
||||
b: str
|
||||
|
||||
tl_domain = SequenceDomain[Record]([
|
||||
Record(1, "1"),
|
||||
Record(2, "2"),
|
||||
])
|
||||
|
||||
class R0(TupleDataset[Record]):
|
||||
item_tuple = Record
|
||||
|
||||
def _process_item_data(self, item_data, item_index):
|
||||
return (item_data[0],)
|
||||
Reference in New Issue
Block a user