1028 lines
38 KiB
Python
1028 lines
38 KiB
Python
"""
|
|
Domain-generic dataset base with attribute-based splitting and balancing
|
|
|
|
**Marginalizing out the modality layer**
|
|
|
|
With ``domain`` being an instance variable, one possible interpretation of
|
|
the object structures here is that one could completely abstract away
|
|
the domain model, defining only item structures and processing data. You
|
|
could have a single dataset definition for a particular concrete dataset,
|
|
and so long as we're talking about the same items, it can be instantiated
|
|
using *any domain*. You wouldn't need specific subclasses for disk or
|
|
network or in-memory structures; you can tell it directly at runtime.
|
|
|
|
That's an eventual possibility, anyway. As it stands, however, this is
|
|
effectively impossible:
|
|
|
|
You can't easily abstract the batch-to-item splitting process, i.e.,
|
|
``_process_batch_data()``. A list-based version of the dataset you're trying to
|
|
define might have an individual item tuple at every index, whereas a disk-based
|
|
version might have tuples batched across a few files. This can't reliably be
|
|
inferred, nor can it be pushed to the ``Domain``-level without needing equal
|
|
levels of specialization (you'd just end up needing the exact same structural
|
|
distinctions in the ``Domain`` hierarchy). So *somewhere* you need a batch
|
|
splitting implementation that is both item structure-dependent *and*
|
|
domain-dependent...the question is how dynamic you're willing to be about where
|
|
it comes from. Right now, we require this actually be defined in the
|
|
``_process_batch_data()`` method, meaning you'll need a specific ``Dataset``
|
|
class for each domain you want to support (e.g., ``MNISTDisk``, ``MNISTList``,
|
|
``MNISTNetwork``, etc), or at least for each domain where "interpreting" a
|
|
batch could possibly differ. This is a case where the interface is all that
|
|
enforces a distinction: if you've got two domains that can be counted on to
|
|
yield batches in the exact same way and can use the same processing, then you
|
|
could feasibly provide ``Domain`` objects from either at runtime and have no
|
|
issues. We're "structurally blind" to any differentiation beyond the URI and
|
|
resource types by design, so two different domain implementations with the same
|
|
type signature ``Domain[U, R]`` should be expected to work fine at runtime
|
|
(again, so long as they don't also need different batch processing), but that's
|
|
not affording us much flexibility, i.e., most of the time we'll still be
|
|
defining new dataset classes for each domain.
|
|
|
|
I initially flagged this as feasible, however, because one could imagine
|
|
accepting a batch processing method upon instantiation rather than structurally
|
|
bolting it into the ``Dataset`` definition. This would require knowledge of the
|
|
item structure ``I`` as well as the ``Domain[U, R]``, so such a function will
|
|
always have to be ``(I, U, R)``-dependent. It nevertheless would take out some
|
|
of the pain of having to define new dataset classes; instead, you'd just need
|
|
to define the batch processing method. I see this as a worse alternative to
|
|
just defining *inside* a safe context like a new dataset class: you know the
|
|
types you have to respect, and you stick that method exactly in a context where
|
|
it's understood. Freeing this up doesn't lighten the burden of processing
|
|
logic, it just changes *when* it has to be provided, and that's not worth much
|
|
(to me) in this case given the bump in complexity. (Taking this to the extreme:
|
|
you could supply *all* of an object's methods "dynamically" and glue them
|
|
together at runtime so long as they all played nice. But wherever you were
|
|
"laying them out" beforehand is exactly the job of a class to begin with, so
|
|
you don't end up with anything more dynamic. All we're really discussing here
|
|
is pushing around unavoidable complexity inside and outside of the "class
|
|
walls," and in the particular case of ``_process_batch_data()``, it feels much
|
|
better when it's on the inside.)
|
|
"""
|
|
|
|
import math
|
|
import random
|
|
import logging
|
|
from abc import abstractmethod
|
|
from copy import copy
|
|
from bisect import bisect
|
|
from typing import Unpack, TypedDict
|
|
from functools import lru_cache
|
|
from collections import defaultdict
|
|
from collections.abc import Callable, Iterator
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
from torch.utils.data import Dataset
|
|
|
|
from trainlib.utils import job
|
|
from trainlib.domain import Domain, SequenceDomain
|
|
from trainlib.transform import Transform
|
|
|
|
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class DatasetKwargs[I](TypedDict, total=False):
    """
    Keyword arguments shared by ``BatchedDataset``-style constructors.

    All keys are optional (``total=False``); see ``BatchedDataset.__init__``
    for the semantics of each.
    """

    # transform applied to items while batches are loaded (results cached)
    pre_transform: Transform[I]
    # transform applied to items just before they're returned (not cached)
    post_transform: Transform[I]
    # max number of processed batches held in the LRU cache at once
    batch_cache_limit: int
    # whether to load all batches into memory at construction time
    preload: bool
    # worker count used when preloading
    num_workers: int
|
|
|
|
|
|
class BatchedDataset[U, R, I](Dataset):
|
|
"""
|
|
Generic dataset that dynamically pulls batched data from resources (e.g.,
|
|
files, remote locations, streams, etc).
|
|
|
|
The class is generic over a URI type ``U``, a resource type ``R`` (both of
|
|
which are used to concretize a domain ``Domain[U, R]``), and an item type
|
|
``I``.
|
|
|
|
**Batch and item processing flow**
|
|
|
|
.. code-block:: text
|
|
|
|
Domain -> [U] :: self._batch_uris = list(domain)
|
|
|
|
Grab all URIs from Domain iterators. This is made concrete early to
|
|
allow for Dataset sizing, and we need a Sequence representation to
|
|
map integer batch indices into Domains, i.e., when getting the
|
|
corresponding URI:
|
|
|
|
batch_uri = self._batch_uris[batch_index]
|
|
|
|
We let Domains implement iterators over their URIs, but explicitly
|
|
exhaust when initializing Datasets.
|
|
|
|
U -> R :: batch_data = self.domain[batch_uri]
|
|
|
|
Retrieve resource from domain. Resources are viewed as batched
|
|
data, even if only wrapping single items (happens in trivial
|
|
settings).
|
|
|
|
R -> [I] :: self._process_batch_data(batch_data, batch_index)
|
|
|
|
Possibly domain-specific batch processing of resource data into
|
|
explicit Sequence-like structures of items, each of which is
|
|
subject to the provided pre_transform. Processed batches at this
|
|
stage are cached (if enabled).
|
|
|
|
[I] -> I :: self.get_batch(batch_index)[index_in_batch]
|
|
|
|
Select individual items from batches in _get_item. At this stage,
|
|
items are in intermediate states and pulled from the cached
|
|
batches.
|
|
|
|
I -> I :: self._process_item_data(item_data, index)
|
|
|
|
Produce final items with __getitem__, getting intermediate items
|
|
via _get_item and applying the provided post_transform.
|
|
|
|
.. note::
|
|
|
|
As far as positioning, this class is meant to play nice with PyTorch
|
|
DataLoaders, hence the inheritance from ``torch.Dataset``. The value
|
|
add for this over the ``torch.Dataset`` base is almost entirely in the
|
|
logic it implements to map out of *batched resources* that are holding
|
|
data, and flattening it out into typical dataset items. There are also
|
|
some QoL features when it comes to splitting and balancing samples.
|
|
|
|
.. note::
|
|
|
|
Even though ``Domains`` implement iterators over their URIs, this
|
|
doesn't imply a ``BatchedDataset`` is iterable. This just means we can
|
|
walk over the resources that provide data, but we don't necessarily
|
|
presuppose an ordered walk over samples within batches. Point being:
|
|
``torch.Dataset``, not ``torch.IterableDataset``, is the appropriate
|
|
superclass, even when we're working around iterable ``Domains``.
|
|
|
|
.. note::
|
|
|
|
Transforms are expected to operate on ``I``-items and produce
|
|
``I``-items. They shouldn't be the "introducers" of ``I`` types from
|
|
some other intermediate representation, nor should they map from ``I``
|
|
to something else. Point being: the dataset definition should be able
|
|
to map resources ``R`` to ``I`` without a transform: that much should
|
|
be baked into the class definition. If you find you're expecting the
|
|
transform to do that for you, you should consider pulling in some
|
|
common structure across the allowed transforms and make it a fixed part
|
|
of the class.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
domain: Domain[U, R],
|
|
pre_transform: Transform[I] | None = None,
|
|
post_transform: Transform[I] | None = None,
|
|
batch_cache_limit: int | None = None,
|
|
preload: bool = False,
|
|
num_workers: int = 1,
|
|
) -> None:
|
|
"""
|
|
Parameters:
|
|
domain: ``Domain`` object providing access to batched data
|
|
pre_transform: transform to apply over items during loading (in
|
|
``_process_batch_data()``), i.e., *before* going into
|
|
persistent storage
|
|
post_transform: transform to apply just prior to returning an item
|
|
(in ``_process_item_data()``), i.e., only *after* retrieval
|
|
from persistent storage
|
|
batch_cache_limit: the max number of max batches to cache at any
|
|
one time
|
|
preload: whether to load all data into memory during instantiation
|
|
num_workers: number of workers to use when preloading data
|
|
"""
|
|
|
|
self.domain = domain
|
|
self.pre_transform = pre_transform
|
|
self.post_transform = post_transform
|
|
self.batch_cache_limit = batch_cache_limit
|
|
self.num_workers = num_workers
|
|
|
|
logger.info("Fetching URIs...")
|
|
self._batch_uris: list[U] = list(domain)
|
|
|
|
self._indices: list[int] | None = None
|
|
self._dataset_len: int | None = None
|
|
self._num_batches: int = len(domain)
|
|
|
|
self.get_batch: Callable[[int], list[I]] = lru_cache(
|
|
maxsize=batch_cache_limit
|
|
)(self._get_batch)
|
|
|
|
if preload:
|
|
self.load_all(num_workers=num_workers)
|
|
|
|
@abstractmethod
|
|
def _get_dataset_len(self) -> int:
|
|
"""
|
|
Calculate the total dataset size in units of samples (not batches).
|
|
"""
|
|
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
def _get_batch_for_item(self, item_index: int) -> tuple[int, int]:
|
|
"""
|
|
Return the index of the batch containing the item at the provided item
|
|
index, and the index of the item within that batch.
|
|
|
|
The behavior of this method can vary depending on what we know about
|
|
batch sizes, and should therefore be implemented by inheriting classes.
|
|
|
|
Parameters:
|
|
item_index: index of item
|
|
|
|
Returns:
|
|
batch_index: int
|
|
index_in_batch: int
|
|
"""
|
|
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
def _process_batch_data(
|
|
self,
|
|
batch_data: R,
|
|
batch_index: int,
|
|
) -> list[I]:
|
|
"""
|
|
Process raw domain resource data (e.g., parse JSON or load tensors) and
|
|
split accordingly, such that the returned batch is a collection of
|
|
``I`` items.
|
|
|
|
If an inheriting class wants to allow dynamic transforms, this is the
|
|
place to use a provided ``pre_transform``; the collection of items
|
|
produced by this method are cached as a batch, so results from such a
|
|
transform will be stored.
|
|
|
|
Parameters:
|
|
batch_data: tuple of resource data
|
|
batch_index: index of batch
|
|
"""
|
|
|
|
raise NotImplementedError
|
|
|
|
@abstractmethod
|
|
def _process_item_data(
|
|
self,
|
|
item_data: I,
|
|
item_index: int,
|
|
) -> I:
|
|
"""
|
|
Process individual items and produce final tuples.
|
|
|
|
If an inheriting class wants to allow dynamic transforms, this is the
|
|
place to use a provided ``post_transform``; items are pulled from the
|
|
cache (if enabled) and processed before being returned as the final
|
|
tuple outputs (so this processing is not persistent).
|
|
|
|
Parameters:
|
|
item_data: item data
|
|
item_index: index of item
|
|
"""
|
|
|
|
raise NotImplementedError
|
|
|
|
def _get_item(self, item_index: int) -> I:
|
|
"""
|
|
Get the item data and zip with the item header.
|
|
|
|
Items should be the most granular representation of dataset samples
|
|
with maximum detail (i.e., yield a all available information). An
|
|
iterator over these representations across all samples can be retrieved
|
|
with `.items()`.
|
|
|
|
Note that return values from `__getitem__()` are "cleaned up" versions
|
|
of this representation, with minimal info needed for training.
|
|
|
|
Parameters:
|
|
item_index: index of item
|
|
"""
|
|
|
|
if item_index >= len(self):
|
|
raise IndexError
|
|
|
|
# alt indices redefine index count
|
|
item_index = self.indices[item_index]
|
|
batch_index, index_in_batch = self._get_batch_for_item(item_index)
|
|
|
|
return self.get_batch(batch_index)[index_in_batch]
|
|
|
|
def _get_batch(self, batch_index: int) -> list[I]:
|
|
"""
|
|
Return the batch data for the provided index.
|
|
|
|
Note that we require a list return type. This is where the rubber meets
|
|
the road in terms of expanding batches: the outputs here get cached, if
|
|
caching is enabled. If we were to defer batch expansion to some later
|
|
stage, that caching will be more or less worthless. For instance, if
|
|
``._process_batch_data()`` was an iterator, at best the batch
|
|
processing logic could be delayed, but then you'd either 1) cache the
|
|
iterator reference, or 2) have to further delay caching until
|
|
post-batch processing. Further, ``._process_batch_data()`` can (and
|
|
often does) depend on the entire batch, so you can't handle that
|
|
item-wise: you've got to pass all batch data in and can't act on
|
|
slices, so doing this would be irrelevant anyway.
|
|
|
|
How about iterators out from ``_read_resources()``? Then your
|
|
``_process_batch_data()`` can iterate as needed and do the work later?
|
|
The point here is that this distinction is irrelevant because you're
|
|
reading resources and processing the data here in the same method:
|
|
they're always connected, and nothing would notice if you waited
|
|
between steps. The only way this could matter is if you split the
|
|
resource reading and batch processing steps across methods, but when it
|
|
actually comes to accessing/caching the batch, you'd have to expand any
|
|
delayed reads here. There's no way around needing to see all batch data
|
|
at once here, and we don't want to make that ambiguous: ``list`` output
|
|
type it is.
|
|
|
|
Parameters:
|
|
batch_index: index of batch
|
|
"""
|
|
|
|
logger.debug("Batch cache miss, reading from root...")
|
|
|
|
if batch_index >= self._num_batches:
|
|
raise IndexError
|
|
|
|
batch_uri = self._batch_uris[batch_index]
|
|
batch_data = self.domain[batch_uri]
|
|
|
|
return self._process_batch_data(batch_data, batch_index)
|
|
|
|
def load_all(self, num_workers: int | None = None) -> list[list[I]]:
|
|
"""
|
|
Preload all data batches into the cache.
|
|
|
|
Can be useful when dynamically pulling data (as it's requested) isn't
|
|
desired. Requires that `cache_sample_limit=None`, i.e., the cache won't
|
|
continually remove previous batches as they're loaded.
|
|
|
|
Parameters:
|
|
num_workers: number of parallel workers to use for data loading
|
|
"""
|
|
|
|
assert self.batch_cache_limit is None, "Preloading under cache limit"
|
|
|
|
if num_workers is None:
|
|
num_workers = self.num_workers
|
|
|
|
thread_pool = ThreadPoolExecutor(max_workers=num_workers)
|
|
|
|
futures = []
|
|
for batch_index in range(self._num_batches):
|
|
future = thread_pool.submit(
|
|
self.get_batch,
|
|
batch_index,
|
|
)
|
|
futures.append(future)
|
|
|
|
job.process_futures(futures, "Loading dataset batches", "batch")
|
|
thread_pool.shutdown(wait=True)
|
|
|
|
return [future.result() for future in futures]
|
|
|
|
def split(
|
|
self,
|
|
fracs: list[float],
|
|
dataset: "BatchedDataset | None" = None,
|
|
by_attr: str | list[str | None] | None = None,
|
|
shuffle_strata: bool = True,
|
|
) -> list["BatchedDataset"]:
|
|
"""
|
|
Split dataset into fractional pieces by data attribute.
|
|
|
|
If ``by_attr`` is None, recovers typical fractional splitting of
|
|
dataset items, partitioning by size. Using None anywhere will index
|
|
each item into its own bucket, i.e., by its index. For instance:
|
|
|
|
- Splits on the attribute such that each subset contains entire strata
|
|
of the attribute. "Homogeneity within clusters:"
|
|
|
|
.. code-block::
|
|
|
|
by_attr=["color"] -> {("red", 1), ("red", 2)},
|
|
{("blue", 1), ("blue", 2)}
|
|
|
|
- Stratifies by attribute and then splits "by index" within, uniformly
|
|
grabbing samples across strata to form new clusters. "Homogeneity
|
|
across clusters"
|
|
|
|
.. code-block::
|
|
|
|
by_attr=["color", None] -> {("red", 1), ("blue", 1)},
|
|
{("red", 2), ("blue", 2)}
|
|
|
|
Note that the final list of Subsets returned are built from shallow
|
|
copies of the underlying dataset (i.e., ``self``) to allow manual
|
|
intervention with dataset attributes (e.g., setting the splits to have
|
|
different ``transforms``). This is subject to possibly unexpected
|
|
behavior if re-caching data or you need a true copy of all data in
|
|
memory, but should otherwise leave most interactions unchanged.
|
|
|
|
Parameters:
|
|
frac: split fractions for datasets
|
|
dataset: dataset to split, defaults to ``self``. Facilitates
|
|
recursive splitting when multi-attribute splits are needed.
|
|
by_attr: attribute or attributes to use when grouping strata for
|
|
dataset splits. Defaults to ``None``, which will use item
|
|
indices.
|
|
shuffle_strata: shuffle the strata order before split is drawn. We
|
|
parameterize this because a Dataloader-level shuffle operation
|
|
will only change the order of the indices in the resulting
|
|
splits; only a shuffle of the strata order can change the
|
|
actual content of the splits themselves.
|
|
"""
|
|
|
|
if by_attr == []:
|
|
raise ValueError("Cannot parse empty value list")
|
|
|
|
assert (
|
|
math.isclose(sum(fracs), 1) and sum(fracs) <= 1
|
|
), "Fractions do not sum to 1"
|
|
|
|
if isinstance(by_attr, str) or by_attr is None:
|
|
by_attr = [by_attr]
|
|
|
|
if dataset is None:
|
|
dataset = self
|
|
# dataset = DictDataset([
|
|
# self._get_item(i) for i in range(len(self))
|
|
# ])
|
|
|
|
# group samples by specified attr
|
|
attr_dict = defaultdict(list)
|
|
attr_key, by_attr = by_attr[0], by_attr[1:]
|
|
# for i in range(len(dataset)):
|
|
for i, item in enumerate(dataset.items()):
|
|
# item = dataset[i]
|
|
if attr_key is None:
|
|
attr_val = i
|
|
elif attr_key in item:
|
|
attr_val = item[attr_key]
|
|
else:
|
|
raise IndexError(f"Attribute {attr_key} not in dataset item")
|
|
attr_dict[attr_val].append(i)
|
|
|
|
if by_attr == []:
|
|
attr_keys = list(attr_dict.keys())
|
|
|
|
# shuffle keys; randomized group-level split
|
|
if shuffle_strata:
|
|
random.shuffle(attr_keys)
|
|
|
|
# considering: defer to dataloader shuffle param; should have same
|
|
# effect shuffle values; has no impact on where the split is drawn
|
|
# for attr_vals in attr_dict.values():
|
|
# random.shuffle(attr_vals)
|
|
|
|
# fractionally split over attribute keys
|
|
offset, splits = 0, []
|
|
for frac in fracs[:-1]:
|
|
frac_indices = []
|
|
frac_size = int(frac * len(attr_keys))
|
|
for j in range(offset, offset + frac_size):
|
|
j_indices = attr_dict.pop(attr_keys[j])
|
|
frac_indices.extend(j_indices)
|
|
offset += frac_size
|
|
splits.append(frac_indices)
|
|
|
|
rem_indices = []
|
|
for r_indices in attr_dict.values():
|
|
rem_indices.extend(r_indices)
|
|
splits.append(rem_indices)
|
|
|
|
subsets = []
|
|
for split in splits:
|
|
subset = copy(dataset)
|
|
subset.indices = split
|
|
subsets.append(subset)
|
|
|
|
return subsets
|
|
else:
|
|
splits = [[] for _ in range(len(fracs))]
|
|
for index_split in attr_dict.values():
|
|
# subset = Subset(dataset, index_split)
|
|
subset = copy(dataset)
|
|
subset.indices = index_split
|
|
subset_splits = self.split(
|
|
fracs, subset, by_attr, shuffle_strata
|
|
)
|
|
|
|
# unpack stratified splits
|
|
for i, subset_split in enumerate(subset_splits):
|
|
# splits[i].extend([
|
|
# index_split[s] for s in subset_split.indices
|
|
# ])
|
|
splits[i].extend(subset_split.indices)
|
|
|
|
subsets = []
|
|
for split in splits:
|
|
subset = copy(dataset)
|
|
subset.reset_indices()
|
|
subset.indices = split
|
|
subsets.append(subset)
|
|
|
|
return subsets
|
|
|
|
# considering: defer to dataloader shuffle param; should have same
|
|
# effect shuffle each split after merging; may otherwise be homogenous
|
|
# for split in splits:
|
|
# random.shuffle(split)
|
|
|
|
# return [Subset(copy(self), split) for split in splits]
|
|
|
|
def balance(
|
|
self,
|
|
dataset: "BatchedSubset[U, R, I] | None" = None,
|
|
by_attr: str | list[str | None] | None = None,
|
|
split_min_sizes: list[int] | None = None,
|
|
split_max_sizes: list[int] | None = None,
|
|
shuffle_strata: bool = True,
|
|
) -> None:
|
|
"""
|
|
Balance the distribution of provided attributes over dataset items.
|
|
|
|
This method sets the indices over the dataset according to the result
|
|
of the rebalancing. The indices are produced by the recursive
|
|
``_balance()`` method, which is necessarily separate due to the need
|
|
for a contained recursive approach that doesn't change the underlying
|
|
dataset during execution.
|
|
|
|
Parameters:
|
|
dataset: dataset to split, defaults to ``self``. Facilitates
|
|
recursive splitting when multi-attribute splits are needed.
|
|
by_attr: attribute or attributes to use when grouping strata for
|
|
dataset splits. Defaults to ``None``, which will use item
|
|
indices.
|
|
split_min_sizes: minimum allowed sizes of splits. Must have the
|
|
same length as ``by_attr``.
|
|
split_max_sizes: maximum allowed sizes of splits. Must have the
|
|
same length as ``by_attr``.
|
|
shuffle_strata: shuffle the strata order before split is drawn. We
|
|
parameterize this because a Dataloader-level shuffle operation
|
|
will only change the order of the indices in the resulting
|
|
splits; only a shuffle of the strata order can change the
|
|
actual content of the splits themselves.
|
|
"""
|
|
|
|
self.indices = self._balance(
|
|
dataset,
|
|
by_attr,
|
|
split_min_sizes,
|
|
split_max_sizes,
|
|
shuffle_strata,
|
|
)
|
|
|
|
def _balance(
|
|
self,
|
|
dataset: "BatchedSubset[U, R, I] | None" = None,
|
|
by_attr: str | list[str | None] | None = None,
|
|
split_min_sizes: list[int] | None = None,
|
|
split_max_sizes: list[int] | None = None,
|
|
shuffle_strata: bool = True,
|
|
) -> list[int]:
|
|
"""
|
|
Recursive balancing of items by attribute.
|
|
|
|
.. note::
|
|
|
|
Behavior is a little odd for nested behavior; not exactly perfectly
|
|
uniform throughout. This is a little difficult: you can't exactly
|
|
know ahead of time the size of the subgroups across splits
|
|
|
|
Parameters:
|
|
dataset: dataset to split, defaults to ``self``. Facilitates
|
|
recursive splitting when multi-attribute splits are needed.
|
|
by_attr: attribute or attributes to use when grouping strata for
|
|
dataset splits. Defaults to ``None``, which will use item
|
|
indices.
|
|
split_min_sizes: minimum allowed sizes of splits. Must have the
|
|
same length as ``by_attr``.
|
|
split_max_sizes: maximum allowed sizes of splits. Must have the
|
|
same length as ``by_attr``.
|
|
shuffle_strata: shuffle the strata order before split is drawn. We
|
|
parameterize this because a Dataloader-level shuffle operation
|
|
will only change the order of the indices in the resulting
|
|
splits; only a shuffle of the strata order can change the
|
|
actual content of the splits themselves.
|
|
"""
|
|
|
|
if by_attr == []:
|
|
raise ValueError("Cannot parse empty value list")
|
|
|
|
if isinstance(by_attr, str) or by_attr is None:
|
|
by_attr = [by_attr]
|
|
|
|
if dataset is None:
|
|
dataset = BatchedSubset(self, self.indices)
|
|
|
|
if split_min_sizes == [] or split_min_sizes is None:
|
|
split_min_sizes = [0]
|
|
if split_max_sizes == [] or split_max_sizes is None:
|
|
split_max_sizes = [len(dataset)]
|
|
|
|
# group samples by specified attr
|
|
attr_dict = defaultdict(list)
|
|
attr_key, by_attr = by_attr[0], by_attr[1:]
|
|
for i in range(len(dataset)):
|
|
item = dataset[i]
|
|
if attr_key is None:
|
|
attr_val = i
|
|
elif attr_key in item:
|
|
attr_val = getattr(item, attr_key)
|
|
else:
|
|
raise IndexError(f"Attribute {attr_key} not in dataset item")
|
|
attr_dict[attr_val].append(i)
|
|
|
|
subset_splits = []
|
|
split_min_size, split_min_sizes = (
|
|
split_min_sizes[0],
|
|
split_min_sizes[1:]
|
|
)
|
|
split_max_size, split_max_sizes = (
|
|
split_max_sizes[0],
|
|
split_max_sizes[1:]
|
|
)
|
|
for split_indices in attr_dict.values():
|
|
if by_attr != []:
|
|
subset_indices = self._balance(
|
|
BatchedSubset(dataset, split_indices),
|
|
by_attr,
|
|
split_min_sizes,
|
|
split_max_sizes,
|
|
shuffle_strata,
|
|
)
|
|
split_indices = [split_indices[s] for s in subset_indices]
|
|
subset_splits.append(split_indices)
|
|
|
|
# shuffle splits; randomized group-level split
|
|
if shuffle_strata:
|
|
random.shuffle(subset_splits)
|
|
|
|
# note: split_min_size is smallest allowed, min_split_size is smallest
|
|
# observed
|
|
valid_splits = [
|
|
ss for ss in subset_splits
|
|
if len(ss) >= split_min_size
|
|
]
|
|
min_split_size = min(len(split) for split in valid_splits)
|
|
|
|
subset_indices = []
|
|
for split_indices in valid_splits:
|
|
# if shuffle_strata:
|
|
# random.shuffle(split_indices)
|
|
subset_indices.extend(
|
|
split_indices[: min(min_split_size, split_max_size)]
|
|
)
|
|
|
|
# print(f"{attr_dict.keys()=}")
|
|
# print(f"{[len(s) for s in valid_splits]=}")
|
|
# print(f"{min_split_size=}")
|
|
|
|
return subset_indices
|
|
|
|
@property
|
|
def indices(self) -> list[int]:
|
|
if self._indices is None:
|
|
self._indices = list(range(len(self)))
|
|
return self._indices
|
|
|
|
@indices.setter
|
|
def indices(self, indices: list[int]) -> None:
|
|
"""
|
|
Note: this logic facilitates nested re-indexing over the same base
|
|
dataset. The underlying data remain the same, but when indices get set,
|
|
you're effectively applying a mask over any existing indices, always
|
|
operating *relative* to the existing mask.
|
|
|
|
Parameters:
|
|
indices: list of indices to set
|
|
"""
|
|
|
|
# manually set new size
|
|
self._dataset_len = len(indices)
|
|
|
|
# note: this is a little tricky and compact; follow what happens when
|
|
# _indices aren't already set
|
|
self._indices = [
|
|
self.indices[index]
|
|
for index in indices
|
|
]
|
|
|
|
def reset_indices(self) -> None:
|
|
self._indices = None
|
|
self._dataset_len = None
|
|
|
|
def items(self) -> Iterator[I]:
|
|
for i in range(len(self)):
|
|
yield self._get_item(i)
|
|
|
|
def __len__(self) -> int:
|
|
if self._dataset_len is None:
|
|
self._dataset_len = self._get_dataset_len()
|
|
|
|
return self._dataset_len
|
|
|
|
def __getitem__(self, index: int) -> I:
|
|
"""
|
|
Get the dataset item at the specified index.
|
|
|
|
Parameters:
|
|
index: index of item to retrieve
|
|
"""
|
|
|
|
item_data = self._get_item(index)
|
|
index = self.indices[index]
|
|
|
|
return self._process_item_data(item_data, index)
|
|
|
|
def __iter__(self) -> Iterator[I]:
|
|
"""
|
|
Note: this method isn't technically needed given ``__getitem__`` is
|
|
defined and we operate cleanly over integer indices 0..(N-1), so even
|
|
without an explicit ``__iter__``, Python will fall back to a reliable
|
|
iteration mechanism. We nevertheless implement the trivial logic below
|
|
to convey intent and meet static type checks for iterables.
|
|
"""
|
|
|
|
return (self[i] for i in range(len(self)))
|
|
|
|
|
|
class CompositeBatchedDataset[U, R, I](BatchedDataset[U, R, I]):
|
|
"""
|
|
Dataset class for wrapping individual datasets.
|
|
|
|
.. note::
|
|
Because this remains a valid ``BatchedDataset``, we re-thread the
|
|
generic type variables through the set of composed datasets. That is,
|
|
they must have a common domain type ``Domain[U, R]``.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
datasets: list[BatchedDataset[U, R, I]],
|
|
) -> None:
|
|
"""
|
|
Parameters:
|
|
datasets: list of datasets
|
|
"""
|
|
|
|
self.datasets = datasets
|
|
|
|
self._indices: list[int] | None = None
|
|
self._dataset_len: int | None = None
|
|
|
|
self._item_psum: list[int] = []
|
|
self._batch_psum: list[int] = []
|
|
|
|
def _compute_prefix_sum(self, arr: list[int]) -> list[int]:
|
|
if not arr:
|
|
return []
|
|
|
|
prefix_sum = [0] * (len(arr) + 1)
|
|
for i in range(len(arr)):
|
|
prefix_sum[i + 1] = prefix_sum[i] + arr[i]
|
|
|
|
return prefix_sum
|
|
|
|
def _get_dataset_for_item(self, item_index: int) -> tuple[int, int]:
|
|
dataset_index = bisect(self._item_psum, item_index) - 1
|
|
index_in_dataset = item_index - self._item_psum[dataset_index]
|
|
|
|
return dataset_index, index_in_dataset
|
|
|
|
def _get_dataset_for_batch(self, batch_index: int) -> tuple[int, int]:
|
|
dataset_index = bisect(self._batch_psum, batch_index) - 1
|
|
index_in_dataset = batch_index - self._batch_psum[dataset_index]
|
|
|
|
return dataset_index, index_in_dataset
|
|
|
|
def _get_batch_for_item(self, item_index: int) -> tuple[int, int]:
|
|
index_pair = self._get_dataset_for_item(item_index)
|
|
dataset_index, index_in_dataset = index_pair
|
|
dataset = self.datasets[dataset_index]
|
|
|
|
return dataset._get_batch_for_item(index_in_dataset)
|
|
|
|
def _get_dataset_len(self) -> int:
|
|
self.load_all()
|
|
|
|
dataset_batch_sizes = [len(dataset) for dataset in self.datasets]
|
|
dataset_sizes = [dataset._num_batches for dataset in self.datasets]
|
|
|
|
# this method will only be ran once; set this instance var
|
|
self._item_psum = self._compute_prefix_sum(dataset_batch_sizes)
|
|
self._batch_psum = self._compute_prefix_sum(dataset_sizes)
|
|
|
|
return self._item_psum[-1]
|
|
|
|
def _process_batch_data(
|
|
self,
|
|
batch_data: R,
|
|
batch_index: int,
|
|
) -> list[I]:
|
|
index_pair = self._get_dataset_for_item(batch_index)
|
|
dataset_index, index_in_dataset = index_pair
|
|
dataset = self.datasets[dataset_index]
|
|
|
|
return dataset._process_batch_data(batch_data, index_in_dataset)
|
|
|
|
def _process_item_data(self, item_data: I, item_index: int) -> I:
|
|
index_pair = self._get_dataset_for_item(item_index)
|
|
dataset_index, index_in_dataset = index_pair
|
|
dataset = self.datasets[dataset_index]
|
|
|
|
return dataset._process_item_data(item_data, index_in_dataset)
|
|
|
|
def _get_item(self, item_index: int) -> I:
|
|
item_index = self.indices[item_index]
|
|
|
|
index_pair = self._get_dataset_for_item(item_index)
|
|
dataset_index, index_in_dataset = index_pair
|
|
dataset = self.datasets[dataset_index]
|
|
|
|
return dataset._get_item(index_in_dataset)
|
|
|
|
def _get_batch(self, batch_index: int) -> list[I]:
|
|
index_pair = self._get_dataset_for_item(batch_index)
|
|
dataset_index, index_in_dataset = index_pair
|
|
dataset = self.datasets[dataset_index]
|
|
|
|
return dataset._get_batch(index_in_dataset)
|
|
|
|
def load_all(
|
|
self,
|
|
num_workers: int | None = None
|
|
) -> list[list[I]]:
|
|
batches = []
|
|
for dataset in self.datasets:
|
|
batches.extend(dataset.load_all(num_workers))
|
|
return batches
|
|
|
|
@property
|
|
def pre_transform(self) -> list[Transform[I] | None]:
|
|
return [dataset.pre_transform for dataset in self.datasets]
|
|
|
|
@pre_transform.setter
|
|
def pre_transform(self, pre_transform: Transform[I]) -> None:
|
|
for dataset in self.datasets:
|
|
dataset.pre_transform = pre_transform
|
|
|
|
@property
|
|
def post_transform(self) -> list[Transform[I] | None]:
|
|
return [dataset.post_transform for dataset in self.datasets]
|
|
|
|
@post_transform.setter
|
|
def post_transform(self, post_transform: Transform[I]) -> None:
|
|
for dataset in self.datasets:
|
|
dataset.post_transform = post_transform
|
|
|
|
|
|
class BatchedSubset[U, R, I](BatchedDataset[U, R, I]):
|
|
def __init__(
|
|
self,
|
|
dataset: BatchedDataset[U, R, I],
|
|
indices: list[int],
|
|
) -> None:
|
|
self.dataset = dataset
|
|
self.indices = indices
|
|
|
|
def _get_item(self, item_index: int) -> I:
|
|
"""
|
|
Subset indices are "reset" in its context. Simply passes through
|
|
"""
|
|
|
|
return self.dataset._get_item(self.indices[item_index])
|
|
|
|
def _get_dataset_len(self) -> int:
|
|
return len(self.indices)
|
|
|
|
|
|
class HomogenousDataset[U, R, I](BatchedDataset[U, R, I]):
|
|
"""
|
|
Batched dataset where batches are equally sized.
|
|
|
|
Subclass from this base when you can count on the reference data being
|
|
prepared with fixed size batches (up to the last batch), e.g., that which
|
|
has been prepared with a `Packer`. This can greatly improve measurement
|
|
time of the dataset size by preventing the need for reading all batch files
|
|
upfront, and reduces the cost of identifying item batches from O(log n) to
|
|
O(1).
|
|
|
|
Methods left for inheriting classes:
|
|
|
|
- ``_process_item_data()``: item processing
|
|
- ``_process_batch_data()``: batch processing
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
domain: Domain[U, R],
|
|
**kwargs: Unpack[DatasetKwargs],
|
|
) -> None:
|
|
super().__init__(domain, **kwargs)
|
|
|
|
# determine batch size across dataset, along w/ possible partial final
|
|
# batch
|
|
bsize = rem = len(self.get_batch(self._num_batches - 1))
|
|
if self._num_batches > 1:
|
|
bsize, rem = len(self.get_batch(self._num_batches - 2)), bsize
|
|
|
|
self._batch_size: int = bsize
|
|
self._batch_rem: int = rem
|
|
|
|
def _get_dataset_len(self) -> int:
|
|
return self._batch_size * (self._num_batches - 1) + self._batch_rem
|
|
|
|
def _get_batch_for_item(self, item_index: int) -> tuple[int, int]:
|
|
return item_index // self._batch_size, item_index % self._batch_size
|
|
|
|
|
|
class HeterogenousDataset[U, R, I](BatchedDataset[U, R, I]):
|
|
"""
|
|
Batched dataset where batches may have arbitrary size.
|
|
|
|
Methods left for inheriting classes:
|
|
|
|
- ``_process_item_data()``: item processing
|
|
- ``_process_batch_data()``: batch processing
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
domain: Domain[U, R],
|
|
**kwargs: Unpack[DatasetKwargs],
|
|
) -> None:
|
|
super().__init__(domain, **kwargs)
|
|
|
|
self._batch_size_psum: list[int] = []
|
|
|
|
def _compute_prefix_sum(self, arr: list[int]) -> list[int]:
|
|
if not arr:
|
|
return []
|
|
|
|
prefix_sum = [0] * (len(arr) + 1)
|
|
for i in range(len(arr)):
|
|
prefix_sum[i + 1] = prefix_sum[i] + arr[i]
|
|
|
|
return prefix_sum
|
|
|
|
def _get_dataset_len(self) -> int:
|
|
# type error below: no idea why this is flagged
|
|
batches = self.load_all()
|
|
batch_sizes = [len(batch) for batch in batches]
|
|
|
|
# this method will only be ran once; set this instance var
|
|
self._batch_size_psum = self._compute_prefix_sum(batch_sizes)
|
|
|
|
return self._batch_size_psum[-1]
|
|
|
|
def _get_batch_for_item(self, item_index: int) -> tuple[int, int]:
|
|
batch_index = bisect(self._batch_size_psum, item_index) - 1
|
|
index_in_batch = item_index - self._batch_size_psum[batch_index]
|
|
|
|
return batch_index, index_in_batch
|
|
|
|
|
|
class SequenceDataset[I](HomogenousDataset[int, I, I]):
|
|
"""
|
|
Trivial dataset skeleton for sequence domains.
|
|
|
|
``I``-typed sequence items map directly to dataset items. To produce a
|
|
fully concrete dataset, one still needs to define ``_process_item_data()``
|
|
to map from ``I``-items to tuples.
|
|
"""
|
|
|
|
domain: SequenceDomain[I]
|
|
|
|
def _process_batch_data(
|
|
self,
|
|
batch_data: I,
|
|
batch_index: int,
|
|
) -> list[I]:
|
|
if self.pre_transform is not None:
|
|
batch_data = self.pre_transform(batch_data)
|
|
|
|
return [batch_data]
|
|
|
|
|
|
class TupleDataset[T](SequenceDataset[tuple[T, ...]]):
|
|
"""
|
|
Trivial sequence-of-tuples dataset.
|
|
|
|
This is the most straightforward line to a concrete dataset from a
|
|
``BatchedDataset`` base class. That is: the underlying domain is a sequence
|
|
whose items are mapped to single-item batches and are already tuples.
|
|
"""
|
|
|
|
def _process_item_data(
|
|
self,
|
|
item_data: tuple[T, ...],
|
|
item_index: int,
|
|
) -> tuple[T, ...]:
|
|
if self.post_transform is not None:
|
|
item_data = self.post_transform(item_data)
|
|
|
|
return item_data
|