# trainlib/trainlib/dataset.py
"""
Domain-generic dataset base with attribute-based splitting and balancing
**Marginalizing out the modality layer**
With ``domain`` being an instance variable, one possible interpretation of
the object structures here is that one could completely abstract away
the domain model, defining only item structures and processing data. You
could have a single dataset definition for a particular concrete dataset,
and so long as we're talking about the same items, it can be instantiated
using *any domain*. You wouldn't need specific subclasses for disk or
network or in-memory structures; you can tell it directly at runtime.
That's an eventual possibility, anyway. As it stands, however, this is
effectively impossible:
You can't easily abstract the batch-to-item splitting process, i.e.,
``_process_batch_data()``. A list-based version of the dataset you're trying to
define might have an individual item tuple at every index, whereas a disk-based
version might have tuples batched across a few files. This can't reliably be
inferred, nor can it be pushed to the ``Domain``-level without needing equal
levels of specialization (you'd just end up needing the exact same structural
distinctions in the ``Domain`` hierarchy). So *somewhere* you need a batch
splitting implementation that is both item structure-dependent *and*
domain-dependent...the question is how dynamic you're willing to be about where
it comes from. Right now, we require this actually be defined in the
``_process_batch_data()`` method, meaning you'll need a specific ``Dataset``
class for each domain you want to support (e.g., ``MNISTDisk``, ``MNISTList``,
``MNISTNetwork``, etc), or at least for each domain where "interpreting" a
batch could possibly differ. This is a case where the interface is all that
enforces a distinction: if you've got two domains that can be counted on to
yield batches in the exact same way and can use the same processing, then you
could feasibly provide ``Domain`` objects from either at runtime and have no
issues. We're "structurally blind" to any differentiation beyond the URI and
resource types by design, so two different domain implementations with the same
type signature ``Domain[U, R]`` should be expected to work fine at runtime
(again, so long as they don't also need different batch processing), but that's
not affording us much flexibility, i.e., most of the time we'll still be
defining new dataset classes for each domain.
I initially flagged this as feasible, however, because one could imagine
accepting a batch processing method upon instantiation rather than structurally
bolting it into the ``Dataset`` definition. This would require knowledge of the
item structure ``I`` as well as the ``Domain[U, R]``, so such a function will
always have to be ``(I, U, R)``-dependent. It nevertheless would take out some
of the pain of having to define new dataset classes; instead, you'd just need
to define the batch processing method. I see this as a worse alternative to
just defining *inside* a safe context like a new dataset class: you know the
types you have to respect, and you stick that method exactly in a context where
it's understood. Freeing this up doesn't lighten the burden of processing
logic, it just changes *when* it has to be provided, and that's not worth much
(to me) in this case given the bump in complexity. (Taking this to the extreme:
you could supply *all* of an object's methods "dynamically" and glue them
together at runtime so long as they all played nice. But wherever you were
"laying them out" beforehand is exactly the job of a class to begin with, so
you don't end up with anything more dynamic. All we're really discussing here
is pushing around unavoidable complexity inside and outside of the "class
walls," and in the particular case of ``_process_batch_data()``, it feels much
better when it's on the inside.)
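To make the trade-off concrete, here is a minimal sketch of the instantiation-time alternative (hypothetical names; ``InjectedDataset`` is not part of this library):

```python
from collections.abc import Callable

# Hypothetical sketch: a dataset that accepts its batch-splitting logic at
# instantiation instead of baking it into a subclass. Type parameters
# (U, R, I) are elided for brevity.
class InjectedDataset:
    def __init__(self, resources, process_batch: Callable):
        # ``resources`` stands in for a Domain[U, R]: URI -> resource
        self._uris = list(resources)
        self._resources = resources
        self._process_batch = process_batch

    def get_batch(self, batch_index: int) -> list:
        resource = self._resources[self._uris[batch_index]]
        return self._process_batch(resource, batch_index)

# resources here are comma-separated strings; the caller supplies the
# (I, U, R)-dependent splitting logic at runtime
ds = InjectedDataset(
    {0: "a,b", 1: "c,d"},
    process_batch=lambda data, _index: data.split(","),
)
```

The caller-supplied ``process_batch`` plays the role of ``_process_batch_data()``: it still has to know both the resource layout and the item structure, which is exactly the ``(I, U, R)``-coupling discussed above.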
"""
import math
import random
import logging
from abc import abstractmethod
from copy import copy
from bisect import bisect
from typing import Unpack, TypedDict
from functools import lru_cache
from collections import defaultdict
from collections.abc import Callable, Iterator
from concurrent.futures import ThreadPoolExecutor
from torch.utils.data import Dataset
from trainlib.utils import job
from trainlib.domain import Domain, SequenceDomain
from trainlib.transform import Transform
logger: logging.Logger = logging.getLogger(__name__)
class DatasetKwargs[I](TypedDict, total=False):
pre_transform: Transform[I]
post_transform: Transform[I]
    batch_cache_limit: int | None
preload: bool
num_workers: int
class BatchedDataset[U, R, I](Dataset):
"""
Generic dataset that dynamically pulls batched data from resources (e.g.,
files, remote locations, streams, etc).
The class is generic over a URI type ``U``, a resource type ``R`` (both of
which are used to concretize a domain ``Domain[U, R]``), and an item type
``I``.
**Batch and item processing flow**
.. code-block:: text
Domain -> [U] :: self._batch_uris = list(domain)
Grab all URIs from Domain iterators. This is made concrete early to
allow for Dataset sizing, and we need a Sequence representation to
map integer batch indices into Domains, i.e., when getting the
corresponding URI:
batch_uri = self._batch_uris[batch_index]
We let Domains implement iterators over their URIs, but explicitly
exhaust when initializing Datasets.
U -> R :: batch_data = self.domain[batch_uri]
Retrieve resource from domain. Resources are viewed as batched
data, even if only wrapping single items (happens in trivial
settings).
R -> [I] :: self._process_batch_data(batch_data, batch_index)
Possibly domain-specific batch processing of resource data into
explicit Sequence-like structures of items, each of which is
subject to the provided pre_transform. Processed batches at this
stage are cached (if enabled).
[I] -> I :: self.get_batch(batch_index)[index_in_batch]
Select individual items from batches in _get_item. At this stage,
items are in intermediate states and pulled from the cached
batches.
I -> I :: self._process_item_data(item_data, index)
Produce final items with __getitem__, getting intermediate items
via _get_item and applying the provided post_transform.
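The stages above can be traced end to end with a toy, dict-backed stand-in (hypothetical structures, not a real ``Domain``):

```python
# Toy walk through Domain -> [U] -> R -> [I] -> I using plain dicts
domain = {"batch0": [10, 11], "batch1": [12, 13]}  # U -> R
batch_uris = list(domain)                          # Domain -> [U]

def process_batch_data(batch_data, batch_index):
    # R -> [I]: expand a resource into items, applying a "pre_transform"
    return [x * 2 for x in batch_data]

def get_item(item_index):
    # [I] -> I: locate the containing batch, then select within it
    batch_index, index_in_batch = divmod(item_index, 2)
    batch = process_batch_data(domain[batch_uris[batch_index]], batch_index)
    return batch[index_in_batch]
```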
.. note::
As far as positioning, this class is meant to play nice with PyTorch
DataLoaders, hence the inheritance from ``torch.Dataset``. The value
add for this over the ``torch.Dataset`` base is almost entirely in the
logic it implements to map out of *batched resources* that are holding
data, and flattening it out into typical dataset items. There are also
some QoL features when it comes to splitting and balancing samples.
.. note::
Even though ``Domains`` implement iterators over their URIs, this
doesn't imply a ``BatchedDataset`` is iterable. This just means we can
walk over the resources that provide data, but we don't necessarily
presuppose an ordered walk over samples within batches. Point being:
``torch.Dataset``, not ``torch.IterableDataset``, is the appropriate
superclass, even when we're working around iterable ``Domains``.
.. note::
Transforms are expected to operate on ``I``-items and produce
``I``-items. They shouldn't be the "introducers" of ``I`` types from
some other intermediate representation, nor should they map from ``I``
to something else. Point being: the dataset definition should be able
to map resources ``R`` to ``I`` without a transform: that much should
be baked into the class definition. If you find you're expecting the
transform to do that for you, you should consider pulling in some
common structure across the allowed transforms and make it a fixed part
of the class.
"""
def __init__(
self,
domain: Domain[U, R],
pre_transform: Transform[I] | None = None,
post_transform: Transform[I] | None = None,
batch_cache_limit: int | None = None,
preload: bool = False,
num_workers: int = 1,
) -> None:
"""
Parameters:
domain: ``Domain`` object providing access to batched data
pre_transform: transform to apply over items during loading (in
``_process_batch_data()``), i.e., *before* going into
persistent storage
post_transform: transform to apply just prior to returning an item
(in ``_process_item_data()``), i.e., only *after* retrieval
from persistent storage
            batch_cache_limit: the max number of batches to cache at any
                one time
preload: whether to load all data into memory during instantiation
num_workers: number of workers to use when preloading data
"""
self.domain = domain
self.pre_transform = pre_transform
self.post_transform = post_transform
self.batch_cache_limit = batch_cache_limit
self.num_workers = num_workers
logger.info("Fetching URIs...")
self._batch_uris: list[U] = list(domain)
self._indices: list[int] | None = None
self._dataset_len: int | None = None
self._num_batches: int = len(domain)
self.get_batch: Callable[[int], list[I]] = lru_cache(
maxsize=batch_cache_limit
)(self._get_batch)
if preload:
self.load_all(num_workers=num_workers)
@abstractmethod
def _get_dataset_len(self) -> int:
"""
Calculate the total dataset size in units of samples (not batches).
"""
raise NotImplementedError
@abstractmethod
def _get_batch_for_item(self, item_index: int) -> tuple[int, int]:
"""
Return the index of the batch containing the item at the provided item
index, and the index of the item within that batch.
The behavior of this method can vary depending on what we know about
batch sizes, and should therefore be implemented by inheriting classes.
Parameters:
item_index: index of item
Returns:
batch_index: int
index_in_batch: int
"""
raise NotImplementedError
@abstractmethod
def _process_batch_data(
self,
batch_data: R,
batch_index: int,
) -> list[I]:
"""
Process raw domain resource data (e.g., parse JSON or load tensors) and
split accordingly, such that the returned batch is a collection of
``I`` items.
If an inheriting class wants to allow dynamic transforms, this is the
place to use a provided ``pre_transform``; the collection of items
produced by this method are cached as a batch, so results from such a
transform will be stored.
Parameters:
batch_data: tuple of resource data
batch_index: index of batch
"""
raise NotImplementedError
@abstractmethod
def _process_item_data(
self,
item_data: I,
item_index: int,
) -> I:
"""
Process individual items and produce final tuples.
If an inheriting class wants to allow dynamic transforms, this is the
place to use a provided ``post_transform``; items are pulled from the
cache (if enabled) and processed before being returned as the final
tuple outputs (so this processing is not persistent).
Parameters:
item_data: item data
item_index: index of item
"""
raise NotImplementedError
def _get_item(self, item_index: int) -> I:
"""
Get the item data and zip with the item header.
Items should be the most granular representation of dataset samples
        with maximum detail (i.e., yield all available information). An
        iterator over these representations across all samples can be
        retrieved with ``items()``.
        Note that return values from ``__getitem__()`` are "cleaned up" versions
of this representation, with minimal info needed for training.
Parameters:
item_index: index of item
"""
if item_index >= len(self):
raise IndexError
# alt indices redefine index count
item_index = self.indices[item_index]
batch_index, index_in_batch = self._get_batch_for_item(item_index)
return self.get_batch(batch_index)[index_in_batch]
def _get_batch(self, batch_index: int) -> list[I]:
"""
Return the batch data for the provided index.
Note that we require a list return type. This is where the rubber meets
the road in terms of expanding batches: the outputs here get cached, if
caching is enabled. If we were to defer batch expansion to some later
stage, that caching will be more or less worthless. For instance, if
``._process_batch_data()`` was an iterator, at best the batch
processing logic could be delayed, but then you'd either 1) cache the
iterator reference, or 2) have to further delay caching until
post-batch processing. Further, ``._process_batch_data()`` can (and
often does) depend on the entire batch, so you can't handle that
item-wise: you've got to pass all batch data in and can't act on
slices, so doing this would be irrelevant anyway.
How about iterators out from ``_read_resources()``? Then your
``_process_batch_data()`` can iterate as needed and do the work later?
The point here is that this distinction is irrelevant because you're
reading resources and processing the data here in the same method:
they're always connected, and nothing would notice if you waited
between steps. The only way this could matter is if you split the
resource reading and batch processing steps across methods, but when it
actually comes to accessing/caching the batch, you'd have to expand any
delayed reads here. There's no way around needing to see all batch data
at once here, and we don't want to make that ambiguous: ``list`` output
type it is.
Parameters:
batch_index: index of batch
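The iterator-caching pitfall argued above can be shown directly: with ``lru_cache``, caching a generator caches the *reference*, so a second hit returns an already-exhausted iterator, while expanding to a ``list`` before caching is safe:

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def get_batch_lazy(batch_index: int):
    # returns a generator: the *reference* is what gets cached
    return (i for i in range(3))

first = list(get_batch_lazy(0))   # consumes the cached generator
second = list(get_batch_lazy(0))  # same object, already exhausted

@lru_cache(maxsize=None)
def get_batch_eager(batch_index: int) -> list[int]:
    # expanded before caching: safe to reuse on every hit
    return list(range(3))
```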
"""
logger.debug("Batch cache miss, reading from root...")
if batch_index >= self._num_batches:
raise IndexError
batch_uri = self._batch_uris[batch_index]
batch_data = self.domain[batch_uri]
return self._process_batch_data(batch_data, batch_index)
def load_all(self, num_workers: int | None = None) -> list[list[I]]:
"""
Preload all data batches into the cache.
Can be useful when dynamically pulling data (as it's requested) isn't
        desired. Requires that ``batch_cache_limit=None``, i.e., the cache won't
continually remove previous batches as they're loaded.
Parameters:
num_workers: number of parallel workers to use for data loading
"""
        assert self.batch_cache_limit is None, (
            "Cannot preload with a batch cache limit set"
        )
if num_workers is None:
num_workers = self.num_workers
thread_pool = ThreadPoolExecutor(max_workers=num_workers)
futures = []
for batch_index in range(self._num_batches):
future = thread_pool.submit(
self.get_batch,
batch_index,
)
futures.append(future)
job.process_futures(futures, "Loading dataset batches", "batch")
thread_pool.shutdown(wait=True)
return [future.result() for future in futures]
def split(
self,
fracs: list[float],
dataset: "BatchedDataset | None" = None,
by_attr: str | list[str | None] | None = None,
shuffle_strata: bool = True,
) -> list["BatchedDataset"]:
"""
Split dataset into fractional pieces by data attribute.
If ``by_attr`` is None, recovers typical fractional splitting of
dataset items, partitioning by size. Using None anywhere will index
each item into its own bucket, i.e., by its index. For instance:
- Splits on the attribute such that each subset contains entire strata
of the attribute. "Homogeneity within clusters:"
.. code-block::
by_attr=["color"] -> {("red", 1), ("red", 2)},
{("blue", 1), ("blue", 2)}
- Stratifies by attribute and then splits "by index" within, uniformly
grabbing samples across strata to form new clusters. "Homogeneity
across clusters"
.. code-block::
by_attr=["color", None] -> {("red", 1), ("blue", 1)},
{("red", 2), ("blue", 2)}
Note that the final list of Subsets returned are built from shallow
copies of the underlying dataset (i.e., ``self``) to allow manual
intervention with dataset attributes (e.g., setting the splits to have
different ``transforms``). This is subject to possibly unexpected
behavior if re-caching data or you need a true copy of all data in
memory, but should otherwise leave most interactions unchanged.
Parameters:
            fracs: split fractions for datasets
dataset: dataset to split, defaults to ``self``. Facilitates
recursive splitting when multi-attribute splits are needed.
by_attr: attribute or attributes to use when grouping strata for
dataset splits. Defaults to ``None``, which will use item
indices.
shuffle_strata: shuffle the strata order before split is drawn. We
parameterize this because a Dataloader-level shuffle operation
will only change the order of the indices in the resulting
splits; only a shuffle of the strata order can change the
actual content of the splits themselves.
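A standalone sketch of the whole-strata case (``by_attr=["color"]``) on toy items, mirroring this method's grouping logic without calling it:

```python
from collections import defaultdict

items = [
    {"color": "red"}, {"color": "red"},
    {"color": "blue"}, {"color": "blue"},
]

# group item indices by attribute value (strata)
attr_dict = defaultdict(list)
for i, item in enumerate(items):
    attr_dict[item["color"]].append(i)

# fractionally split over strata keys: each split keeps whole strata
attr_keys = list(attr_dict)
frac_size = int(0.5 * len(attr_keys))
split_a = [i for k in attr_keys[:frac_size] for i in attr_dict[k]]
split_b = [i for k in attr_keys[frac_size:] for i in attr_dict[k]]
```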
"""
if by_attr == []:
raise ValueError("Cannot parse empty value list")
assert (
math.isclose(sum(fracs), 1) and sum(fracs) <= 1
), "Fractions do not sum to 1"
if isinstance(by_attr, str) or by_attr is None:
by_attr = [by_attr]
if dataset is None:
dataset = self
# group samples by specified attr
attr_dict = defaultdict(list)
attr_key, by_attr = by_attr[0], by_attr[1:]
        for i, item in enumerate(dataset.items()):
if attr_key is None:
attr_val = i
elif attr_key in item:
attr_val = item[attr_key]
else:
raise IndexError(f"Attribute {attr_key} not in dataset item")
attr_dict[attr_val].append(i)
if by_attr == []:
attr_keys = list(attr_dict.keys())
# shuffle keys; randomized group-level split
if shuffle_strata:
random.shuffle(attr_keys)
            # considering: defer to dataloader shuffle param (should have the
            # same effect). Shuffling values has no impact on where the split
            # is drawn:
            # for attr_vals in attr_dict.values():
            #     random.shuffle(attr_vals)
# fractionally split over attribute keys
offset, splits = 0, []
for frac in fracs[:-1]:
frac_indices = []
frac_size = int(frac * len(attr_keys))
for j in range(offset, offset + frac_size):
j_indices = attr_dict.pop(attr_keys[j])
frac_indices.extend(j_indices)
offset += frac_size
splits.append(frac_indices)
rem_indices = []
for r_indices in attr_dict.values():
rem_indices.extend(r_indices)
splits.append(rem_indices)
subsets = []
for split in splits:
subset = copy(dataset)
subset.indices = split
subsets.append(subset)
return subsets
else:
splits = [[] for _ in range(len(fracs))]
for index_split in attr_dict.values():
# subset = Subset(dataset, index_split)
subset = copy(dataset)
subset.indices = index_split
subset_splits = self.split(
fracs, subset, by_attr, shuffle_strata
)
# unpack stratified splits
for i, subset_split in enumerate(subset_splits):
# splits[i].extend([
# index_split[s] for s in subset_split.indices
# ])
splits[i].extend(subset_split.indices)
subsets = []
for split in splits:
subset = copy(dataset)
subset.reset_indices()
subset.indices = split
subsets.append(subset)
return subsets
        # considering: defer to dataloader shuffle param (should have the same
        # effect). Shuffle each split after merging; may otherwise be
        # homogenous:
        # for split in splits:
        #     random.shuffle(split)
def balance(
self,
dataset: "BatchedSubset[U, R, I] | None" = None,
by_attr: str | list[str | None] | None = None,
split_min_sizes: list[int] | None = None,
split_max_sizes: list[int] | None = None,
shuffle_strata: bool = True,
) -> None:
"""
Balance the distribution of provided attributes over dataset items.
This method sets the indices over the dataset according to the result
of the rebalancing. The indices are produced by the recursive
``_balance()`` method, which is necessarily separate due to the need
for a contained recursive approach that doesn't change the underlying
dataset during execution.
Parameters:
dataset: dataset to split, defaults to ``self``. Facilitates
recursive splitting when multi-attribute splits are needed.
by_attr: attribute or attributes to use when grouping strata for
dataset splits. Defaults to ``None``, which will use item
indices.
split_min_sizes: minimum allowed sizes of splits. Must have the
same length as ``by_attr``.
split_max_sizes: maximum allowed sizes of splits. Must have the
same length as ``by_attr``.
shuffle_strata: shuffle the strata order before split is drawn. We
parameterize this because a Dataloader-level shuffle operation
will only change the order of the indices in the resulting
splits; only a shuffle of the strata order can change the
actual content of the splits themselves.
"""
self.indices = self._balance(
dataset,
by_attr,
split_min_sizes,
split_max_sizes,
shuffle_strata,
)
def _balance(
self,
dataset: "BatchedSubset[U, R, I] | None" = None,
by_attr: str | list[str | None] | None = None,
split_min_sizes: list[int] | None = None,
split_max_sizes: list[int] | None = None,
shuffle_strata: bool = True,
) -> list[int]:
"""
Recursive balancing of items by attribute.
.. note::
            Behavior is a little odd for nested attributes; the result is not
            perfectly uniform throughout. This is difficult to avoid: you
            can't know ahead of time the size of the subgroups across splits.
Parameters:
dataset: dataset to split, defaults to ``self``. Facilitates
recursive splitting when multi-attribute splits are needed.
by_attr: attribute or attributes to use when grouping strata for
dataset splits. Defaults to ``None``, which will use item
indices.
split_min_sizes: minimum allowed sizes of splits. Must have the
same length as ``by_attr``.
split_max_sizes: maximum allowed sizes of splits. Must have the
same length as ``by_attr``.
shuffle_strata: shuffle the strata order before split is drawn. We
parameterize this because a Dataloader-level shuffle operation
will only change the order of the indices in the resulting
splits; only a shuffle of the strata order can change the
actual content of the splits themselves.
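A standalone sketch of the single-attribute core: group by attribute, drop strata below the minimum size, and truncate the survivors to the smallest observed stratum:

```python
from collections import defaultdict

labels = ["a", "a", "a", "b", "b", "c"]

# strata: attribute value -> item indices
strata = defaultdict(list)
for i, label in enumerate(labels):
    strata[label].append(i)

split_min_size = 2  # drop strata smaller than this
valid = [ixs for ixs in strata.values() if len(ixs) >= split_min_size]
min_split_size = min(len(ixs) for ixs in valid)

# truncate every surviving stratum to the smallest observed size
balanced = [i for ixs in valid for i in ixs[:min_split_size]]
```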
"""
if by_attr == []:
raise ValueError("Cannot parse empty value list")
if isinstance(by_attr, str) or by_attr is None:
by_attr = [by_attr]
if dataset is None:
dataset = BatchedSubset(self, self.indices)
if split_min_sizes == [] or split_min_sizes is None:
split_min_sizes = [0]
if split_max_sizes == [] or split_max_sizes is None:
split_max_sizes = [len(dataset)]
# group samples by specified attr
attr_dict = defaultdict(list)
attr_key, by_attr = by_attr[0], by_attr[1:]
for i in range(len(dataset)):
item = dataset[i]
if attr_key is None:
attr_val = i
elif attr_key in item:
                attr_val = item[attr_key]
else:
raise IndexError(f"Attribute {attr_key} not in dataset item")
attr_dict[attr_val].append(i)
subset_splits = []
split_min_size, split_min_sizes = (
split_min_sizes[0],
split_min_sizes[1:]
)
split_max_size, split_max_sizes = (
split_max_sizes[0],
split_max_sizes[1:]
)
for split_indices in attr_dict.values():
if by_attr != []:
subset_indices = self._balance(
BatchedSubset(dataset, split_indices),
by_attr,
split_min_sizes,
split_max_sizes,
shuffle_strata,
)
split_indices = [split_indices[s] for s in subset_indices]
subset_splits.append(split_indices)
# shuffle splits; randomized group-level split
if shuffle_strata:
random.shuffle(subset_splits)
# note: split_min_size is smallest allowed, min_split_size is smallest
# observed
valid_splits = [
ss for ss in subset_splits
if len(ss) >= split_min_size
]
min_split_size = min(len(split) for split in valid_splits)
subset_indices = []
for split_indices in valid_splits:
# if shuffle_strata:
# random.shuffle(split_indices)
subset_indices.extend(
split_indices[: min(min_split_size, split_max_size)]
)
return subset_indices
@property
def indices(self) -> list[int]:
if self._indices is None:
self._indices = list(range(len(self)))
return self._indices
@indices.setter
def indices(self, indices: list[int]) -> None:
"""
Note: this logic facilitates nested re-indexing over the same base
dataset. The underlying data remain the same, but when indices get set,
you're effectively applying a mask over any existing indices, always
operating *relative* to the existing mask.
Parameters:
indices: list of indices to set
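The relative masking can be seen with a standalone helper mirroring the setter's core:

```python
def compose_indices(existing: list[int], new: list[int]) -> list[int]:
    # new indices are interpreted *relative* to the existing mask
    return [existing[index] for index in new]

base = list(range(10))                        # fresh dataset: identity mask
first = compose_indices(base, [2, 4, 6, 8])   # keep four items
second = compose_indices(first, [0, 3])       # relative to *first*
```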
"""
# manually set new size
self._dataset_len = len(indices)
# note: this is a little tricky and compact; follow what happens when
# _indices aren't already set
self._indices = [
self.indices[index]
for index in indices
]
def reset_indices(self) -> None:
self._indices = None
self._dataset_len = None
def items(self) -> Iterator[I]:
for i in range(len(self)):
yield self._get_item(i)
def __len__(self) -> int:
if self._dataset_len is None:
self._dataset_len = self._get_dataset_len()
return self._dataset_len
def __getitem__(self, index: int) -> I:
"""
Get the dataset item at the specified index.
Parameters:
index: index of item to retrieve
"""
item_data = self._get_item(index)
index = self.indices[index]
return self._process_item_data(item_data, index)
def __iter__(self) -> Iterator[I]:
"""
Note: this method isn't technically needed given ``__getitem__`` is
defined and we operate cleanly over integer indices 0..(N-1), so even
without an explicit ``__iter__``, Python will fall back to a reliable
iteration mechanism. We nevertheless implement the trivial logic below
to convey intent and meet static type checks for iterables.
"""
return (self[i] for i in range(len(self)))
class CompositeBatchedDataset[U, R, I](BatchedDataset[U, R, I]):
"""
Dataset class for wrapping individual datasets.
.. note::
Because this remains a valid ``BatchedDataset``, we re-thread the
generic type variables through the set of composed datasets. That is,
they must have a common domain type ``Domain[U, R]``.
"""
def __init__(
self,
datasets: list[BatchedDataset[U, R, I]],
) -> None:
"""
Parameters:
datasets: list of datasets
"""
self.datasets = datasets
self._indices: list[int] | None = None
self._dataset_len: int | None = None
self._item_psum: list[int] = []
self._batch_psum: list[int] = []
def _compute_prefix_sum(self, arr: list[int]) -> list[int]:
if not arr:
return []
prefix_sum = [0] * (len(arr) + 1)
for i in range(len(arr)):
prefix_sum[i + 1] = prefix_sum[i] + arr[i]
return prefix_sum
def _get_dataset_for_item(self, item_index: int) -> tuple[int, int]:
dataset_index = bisect(self._item_psum, item_index) - 1
index_in_dataset = item_index - self._item_psum[dataset_index]
return dataset_index, index_in_dataset
def _get_dataset_for_batch(self, batch_index: int) -> tuple[int, int]:
dataset_index = bisect(self._batch_psum, batch_index) - 1
index_in_dataset = batch_index - self._batch_psum[dataset_index]
return dataset_index, index_in_dataset
def _get_batch_for_item(self, item_index: int) -> tuple[int, int]:
index_pair = self._get_dataset_for_item(item_index)
dataset_index, index_in_dataset = index_pair
dataset = self.datasets[dataset_index]
return dataset._get_batch_for_item(index_in_dataset)
    def _get_dataset_len(self) -> int:
        self.load_all()
        dataset_item_counts = [len(dataset) for dataset in self.datasets]
        dataset_batch_counts = [
            dataset._num_batches for dataset in self.datasets
        ]
        # this method will only be run once; set these instance vars
        self._item_psum = self._compute_prefix_sum(dataset_item_counts)
        self._batch_psum = self._compute_prefix_sum(dataset_batch_counts)
        return self._item_psum[-1]
def _process_batch_data(
self,
batch_data: R,
batch_index: int,
) -> list[I]:
        index_pair = self._get_dataset_for_batch(batch_index)
dataset_index, index_in_dataset = index_pair
dataset = self.datasets[dataset_index]
return dataset._process_batch_data(batch_data, index_in_dataset)
def _process_item_data(self, item_data: I, item_index: int) -> I:
index_pair = self._get_dataset_for_item(item_index)
dataset_index, index_in_dataset = index_pair
dataset = self.datasets[dataset_index]
return dataset._process_item_data(item_data, index_in_dataset)
def _get_item(self, item_index: int) -> I:
item_index = self.indices[item_index]
index_pair = self._get_dataset_for_item(item_index)
dataset_index, index_in_dataset = index_pair
dataset = self.datasets[dataset_index]
return dataset._get_item(index_in_dataset)
def _get_batch(self, batch_index: int) -> list[I]:
        index_pair = self._get_dataset_for_batch(batch_index)
dataset_index, index_in_dataset = index_pair
dataset = self.datasets[dataset_index]
return dataset._get_batch(index_in_dataset)
def load_all(
self,
num_workers: int | None = None
) -> list[list[I]]:
batches = []
for dataset in self.datasets:
batches.extend(dataset.load_all(num_workers))
return batches
@property
def pre_transform(self) -> list[Transform[I] | None]:
return [dataset.pre_transform for dataset in self.datasets]
@pre_transform.setter
def pre_transform(self, pre_transform: Transform[I]) -> None:
for dataset in self.datasets:
dataset.pre_transform = pre_transform
@property
def post_transform(self) -> list[Transform[I] | None]:
return [dataset.post_transform for dataset in self.datasets]
@post_transform.setter
def post_transform(self, post_transform: Transform[I]) -> None:
for dataset in self.datasets:
dataset.post_transform = post_transform
class BatchedSubset[U, R, I](BatchedDataset[U, R, I]):
def __init__(
self,
dataset: BatchedDataset[U, R, I],
indices: list[int],
) -> None:
        self.dataset = dataset
        # set index state directly: subset indices reference the parent
        # dataset, so we skip the (relative) ``indices`` setter
        self._indices: list[int] | None = indices
        self._dataset_len: int | None = len(indices)
def _get_item(self, item_index: int) -> I:
"""
        Subset indices are "reset" in this context: simply map through the
        subset's indices and pass to the parent dataset.
"""
return self.dataset._get_item(self.indices[item_index])
def _get_dataset_len(self) -> int:
return len(self.indices)
class HomogenousDataset[U, R, I](BatchedDataset[U, R, I]):
"""
Batched dataset where batches are equally sized.
Subclass from this base when you can count on the reference data being
prepared with fixed size batches (up to the last batch), e.g., that which
has been prepared with a `Packer`. This can greatly improve measurement
time of the dataset size by preventing the need for reading all batch files
upfront, and reduces the cost of identifying item batches from O(log n) to
O(1).
Methods left for inheriting classes:
- ``_process_item_data()``: item processing
- ``_process_batch_data()``: batch processing
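With a fixed batch size, locating an item is a single ``divmod`` (standalone sketch):

```python
batch_size = 4

def get_batch_for_item(item_index: int) -> tuple[int, int]:
    # (batch_index, index_in_batch) in O(1), no prefix sums needed
    return divmod(item_index, batch_size)
```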
"""
def __init__(
self,
domain: Domain[U, R],
**kwargs: Unpack[DatasetKwargs],
) -> None:
super().__init__(domain, **kwargs)
# determine batch size across dataset, along w/ possible partial final
# batch
bsize = rem = len(self.get_batch(self._num_batches - 1))
if self._num_batches > 1:
bsize, rem = len(self.get_batch(self._num_batches - 2)), bsize
self._batch_size: int = bsize
self._batch_rem: int = rem
def _get_dataset_len(self) -> int:
return self._batch_size * (self._num_batches - 1) + self._batch_rem
def _get_batch_for_item(self, item_index: int) -> tuple[int, int]:
return item_index // self._batch_size, item_index % self._batch_size
class HeterogenousDataset[U, R, I](BatchedDataset[U, R, I]):
"""
Batched dataset where batches may have arbitrary size.
Methods left for inheriting classes:
- ``_process_item_data()``: item processing
- ``_process_batch_data()``: batch processing
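A standalone sketch of the prefix-sum lookup used here: cumulative batch sizes plus ``bisect`` give the containing batch and the in-batch offset:

```python
from bisect import bisect

batch_sizes = [3, 1, 4]

# prefix sums: psum[i] = number of items before batch i
psum = [0]
for size in batch_sizes:
    psum.append(psum[-1] + size)

def get_batch_for_item(item_index: int) -> tuple[int, int]:
    # O(log n) lookup into the prefix-sum table
    batch_index = bisect(psum, item_index) - 1
    return batch_index, item_index - psum[batch_index]
```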
"""
def __init__(
self,
domain: Domain[U, R],
**kwargs: Unpack[DatasetKwargs],
) -> None:
super().__init__(domain, **kwargs)
self._batch_size_psum: list[int] = []
def _compute_prefix_sum(self, arr: list[int]) -> list[int]:
if not arr:
return []
prefix_sum = [0] * (len(arr) + 1)
for i in range(len(arr)):
prefix_sum[i + 1] = prefix_sum[i] + arr[i]
return prefix_sum
def _get_dataset_len(self) -> int:
batches = self.load_all()
batch_sizes = [len(batch) for batch in batches]
        # this method will only be run once; set this instance var
self._batch_size_psum = self._compute_prefix_sum(batch_sizes)
return self._batch_size_psum[-1]
def _get_batch_for_item(self, item_index: int) -> tuple[int, int]:
batch_index = bisect(self._batch_size_psum, item_index) - 1
index_in_batch = item_index - self._batch_size_psum[batch_index]
return batch_index, index_in_batch
class SequenceDataset[I](HomogenousDataset[int, I, I]):
"""
Trivial dataset skeleton for sequence domains.
``I``-typed sequence items map directly to dataset items. To produce a
fully concrete dataset, one still needs to define ``_process_item_data()``
to map from ``I``-items to tuples.
"""
domain: SequenceDomain[I]
def _process_batch_data(
self,
batch_data: I,
batch_index: int,
) -> list[I]:
if self.pre_transform is not None:
batch_data = self.pre_transform(batch_data)
return [batch_data]
class TupleDataset[T](SequenceDataset[tuple[T, ...]]):
"""
Trivial sequence-of-tuples dataset.
This is the most straightforward line to a concrete dataset from a
``BatchedDataset`` base class. That is: the underlying domain is a sequence
whose items are mapped to single-item batches and are already tuples.
"""
def _process_item_data(
self,
item_data: tuple[T, ...],
item_index: int,
) -> tuple[T, ...]:
if self.post_transform is not None:
item_data = self.post_transform(item_data)
return item_data