From 4c98f7b6293cf83e1fe12afffe59608e32727457 Mon Sep 17 00:00:00 2001
From: "Sam G." <samgriesemer@gmail.com>
Date: Wed, 1 May 2024 20:02:14 -0700
Subject: [PATCH] add key-group collation uniqueness, fix dynamicism in Mapper
 collection

---
 co3/co3.py        | 36 ++++++++++++++------------
 co3/mapper.py     | 65 +++++++++++++++++++++++++++++++++++------------
 tests/test_co3.py |  9 +++----
 3 files changed, 72 insertions(+), 38 deletions(-)

diff --git a/co3/co3.py b/co3/co3.py
index 62baa64..5ccd5f4 100644
--- a/co3/co3.py
+++ b/co3/co3.py
@@ -150,27 +150,17 @@ class FormatRegistryMeta(type):
     Metaclass handling collation registry at the class level.
     '''
     def __new__(cls, name, bases, attrs):
-        key_registry = {}
-        group_registry = defaultdict(list)
+        key_registry = defaultdict(dict)
+        group_registry = defaultdict(set)
 
         def register_action(method):
             nonlocal key_registry, group_registry
 
             if hasattr(method, '_collation_data'):
                 key, groups = method._collation_data
-
-                if key is None:
-                    # only add a "None" entry if there is _some_ implicit group
-                    if None not in key_registry:
-                        key_registry[None] = {}
-
-                    # only a single group possible here
-                    key_registry[None][groups[0]] = method
-                else:
-                    key_registry[key] = (method, groups)
-
                 for group in groups:
-                    group_registry[group].append(key)
+                    key_registry[key][group] = method
+                    group_registry[group].add(key)
 
         # add registered superclass methods; iterate over bases (usually just one), then
         # that base's chain down (reversed), then methods from each subclass
@@ -244,6 +234,14 @@ class CO3(metaclass=FormatRegistryMeta):
         return {}
 
     def collate(self, key, group=None, *args, **kwargs):
+        '''
+        Note:
+            This method is sensitive to group specification. By default, the provided key
+            will be checked against the default ``None`` group, even if that key is only
+            attached to non-default groups. Collation actions are unique on key-group
+            pairs, so more specificity is generally required to correctly execute desired
+            actions (otherwise, rely more heavily on the default group).
+        '''
         if key is None:
             return None
 
@@ -258,13 +256,19 @@ class CO3(metaclass=FormatRegistryMeta):
             method = self.key_registry[None].get(group)
             if method is None:
                 logger.debug(
-                    f'Collation key "{key}" not registered and group {group} not implicit'
+                    f'Collation key "{key}" not registered and group "{group}" not implicit'
                 )
                 return None
 
             return method(self, key, *args, **kwargs)
         else:
-            method = self.key_registry[key][0]
+            method = self.key_registry[key].get(group)
+            if method is None:
+                logger.debug(
+                    f'Collation key "{key}" registered, but group "{group}" is not available'
+                )
+                return None
+
             return method(self, *args, **kwargs)
 
 
diff --git a/co3/mapper.py b/co3/mapper.py
index 98cb02d..46723f8 100644
--- a/co3/mapper.py
+++ b/co3/mapper.py
@@ -213,16 +213,58 @@ class Mapper[C: Component]:
         Parameters:
             obj: CO3 instance to collect from
             keys: keys for actions to collect from
-            group: action group names to run all actions for
+            group: group contexts for the keys to collect from. If None, explicit group
+                   contexts registered for the keys will be inferred (but implicit groups
+                   will not be detected).
 
-        Returns: dict with keys and values relevant for associated SQLite tables
+        Returns: collector receipts for staged inserts
         '''
         # default is to have no actions
         if keys is None:
             keys = []
             #keys = list(obj.key_registry.keys())
 
+        collation_data = defaultdict(dict)
+        for key in keys:
+            # keys must be defined
+            if key is None:
+                continue
+
+            # if groups not specified, dynamically grab those explicitly attached groups
+            # for each key
+            group_dict = {}
+            if groups is None:
+                group_dict = obj.key_registry.get(key, {})
+            else:
+                for group in groups:
+                    group_dict[group] = obj.key_registry.get(key, {}).get(group)
+
+            # method regroup: under key, index by method and run once per
+            method_groups = defaultdict(list)
+            for group_name, group_method in group_dict.items():
+                method_groups[group_method].append(group_name)
+
+            # collate for method equivalence classes; only need on representative group to
+            # pass to CO3.collate to call the method
+            key_collation_data = {}
+            for collation_method, collation_groups in method_groups.items():
+                key_method_collation_data = obj.collate(key, group=collation_groups[0])
+
+                for collation_group in collation_groups:
+                    # gather connective data for collation components
+                    # -> we do this here as it's obj dependent
+                    connective_data = obj.collation_attributes(key, collation_group)
+
+                    key_collation_data[collation_group] = {
+                        **connective_data,
+                        **key_method_collation_data, 
+                    }
+
+            collation_data[key] = key_collation_data
+
         receipts = []
+        attributes = obj.attributes
+
         for _cls in reversed(obj.__class__.__mro__[:-2]):
             attribute_component = self.get_attr_comp(_cls)
 
@@ -232,33 +274,24 @@ class Mapper[C: Component]:
 
             self.collector.add_insert(
                 attribute_component,
-                obj.attributes,
+                attributes,
                 receipts=receipts,
             )
 
-            for key in keys:
-                collation_data = obj.collate(key)
-
+            for key, key_collation_data in collation_data.items():
                 # if method either returned no data or isn't registered, ignore
-                if collation_data is None:
+                if not key_collation_data:
                     continue
 
-                _, groups = obj.key_registry.get(key, (None, []))
-                for group in groups:
+                for group, group_collation_data in key_collation_data.items():
                     collation_component = self.get_coll_comp(_cls, group=group)
 
                     if collation_component is None:
                         continue
 
-                    # gather connective data for collation components
-                    connective_data = obj.collation_attributes(key, group)
-
                     self.collector.add_insert(
                         collation_component,
-                        {
-                            **connective_data,
-                            **collation_data, 
-                        },
+                        group_collation_data,
                         receipts=receipts,
                     )
 
diff --git a/tests/test_co3.py b/tests/test_co3.py
index cc4bfd3..819402d 100644
--- a/tests/test_co3.py
+++ b/tests/test_co3.py
@@ -18,11 +18,8 @@ def test_co3_registry():
     assert set(tomato.key_registry.get(None,{}).keys()) == set(keys_to_groups.get(None,[]))
 
     # check against `registry`, should map keys to all groups
-    for key, group_obj in tomato.key_registry.items():
-        if key is None: continue
-
-        _, groups = group_obj
-        assert keys_to_groups.get(key) == groups
+    for key, group_dict in tomato.key_registry.items():
+        assert keys_to_groups.get(key) == list(group_dict.keys())
 
 def test_co3_attributes():
     assert tomato.attributes is not None
@@ -39,4 +36,4 @@ def test_co3_collate():
     for group, keys in tomato.group_registry.items():
         for key in keys:
             if key is None: continue
-            assert tomato.collate(key) is not None
+            assert tomato.collate(key, group=group) is not None