diff --git a/co3/co3.py b/co3/co3.py index 62baa64..5ccd5f4 100644 --- a/co3/co3.py +++ b/co3/co3.py @@ -150,27 +150,17 @@ class FormatRegistryMeta(type): Metaclass handling collation registry at the class level. ''' def __new__(cls, name, bases, attrs): - key_registry = {} - group_registry = defaultdict(list) + key_registry = defaultdict(dict) + group_registry = defaultdict(set) def register_action(method): nonlocal key_registry, group_registry if hasattr(method, '_collation_data'): key, groups = method._collation_data - - if key is None: - # only add a "None" entry if there is _some_ implicit group - if None not in key_registry: - key_registry[None] = {} - - # only a single group possible here - key_registry[None][groups[0]] = method - else: - key_registry[key] = (method, groups) - for group in groups: - group_registry[group].append(key) + key_registry[key][group] = method + group_registry[group].add(key) # add registered superclass methods; iterate over bases (usually just one), then # that base's chain down (reversed), then methods from each subclass @@ -244,6 +234,14 @@ class CO3(metaclass=FormatRegistryMeta): return {} def collate(self, key, group=None, *args, **kwargs): + ''' + Note: + This method is sensitive to group specification. By default, the provided key + will be checked against the default ``None`` group, even if that key is only + attached to non-default groups. Collation actions are unique on key-group + pairs, so more specificity is generally required to correctly execute desired + actions (otherwise, rely more heavily on the default group). + ''' if key is None: return None @@ -258,13 +256,19 @@ class CO3(metaclass=FormatRegistryMeta): method = self.key_registry[None].get(group) if method is None: logger.debug( - f'Collation key "{key}" not registered and group {group} not implicit' + f'Collation key "{key}" not registered and group "{group}" not implicit' ) return None return method(self, key, *args, **kwargs) else: - method = self.key_registry[key][0] + method = self.key_registry[key].get(group) + if method is None: + logger.debug( + f'Collation key "{key}" registered, but group "{group}" is not available' + ) + return None + return method(self, *args, **kwargs) diff --git a/co3/mapper.py b/co3/mapper.py index 98cb02d..46723f8 100644 --- a/co3/mapper.py +++ b/co3/mapper.py @@ -213,16 +213,58 @@ class Mapper[C: Component]: Parameters: obj: CO3 instance to collect from keys: keys for actions to collect from - group: action group names to run all actions for + group: group contexts for the keys to collect from. If None, explicit group + contexts registered for the keys will be inferred (but implicit groups + will not be detected). - Returns: dict with keys and values relevant for associated SQLite tables + Returns: collector receipts for staged inserts ''' # default is to have no actions if keys is None: keys = [] #keys = list(obj.key_registry.keys()) + collation_data = defaultdict(dict) + for key in keys: + # keys must be defined + if key is None: + continue + + # if groups not specified, dynamically grab those explicitly attached groups + # for each key + group_dict = {} + if groups is None: + group_dict = obj.key_registry.get(key, {}) + else: + for group in groups: + group_dict[group] = obj.key_registry.get(key, {}).get(group) + + # method regroup: under key, index by method and run once per + method_groups = defaultdict(list) + for group_name, group_method in group_dict.items(): + method_groups[group_method].append(group_name) + + # collate for method equivalence classes; only need on representative group to + # pass to CO3.collate to call the method + key_collation_data = {} + for collation_method, collation_groups in method_groups.items(): + key_method_collation_data = obj.collate(key, group=collation_groups[0]) + + for collation_group in collation_groups: + # gather connective data for collation components + # -> we do this here as it's obj dependent + connective_data = obj.collation_attributes(key, collation_group) + + key_collation_data[collation_group] = { + **connective_data, + **key_method_collation_data, + } + + collation_data[key] = key_collation_data + receipts = [] + attributes = obj.attributes + for _cls in reversed(obj.__class__.__mro__[:-2]): attribute_component = self.get_attr_comp(_cls) @@ -232,33 +274,24 @@ class Mapper[C: Component]: self.collector.add_insert( attribute_component, - obj.attributes, + attributes, receipts=receipts, ) - for key in keys: - collation_data = obj.collate(key) - + for key, key_collation_data in collation_data.items(): # if method either returned no data or isn't registered, ignore - if collation_data is None: + if not key_collation_data: continue - _, groups = obj.key_registry.get(key, (None, [])) - for group in groups: + for group, group_collation_data in key_collation_data.items(): collation_component = self.get_coll_comp(_cls, group=group) if collation_component is None: continue - # gather connective data for collation components - connective_data = obj.collation_attributes(key, group) - self.collector.add_insert( collation_component, - { - **connective_data, - **collation_data, - }, + group_collation_data, receipts=receipts, ) diff --git a/tests/test_co3.py b/tests/test_co3.py index cc4bfd3..819402d 100644 --- a/tests/test_co3.py +++ b/tests/test_co3.py @@ -18,11 +18,8 @@ def test_co3_registry(): assert set(tomato.key_registry.get(None,{}).keys()) == set(keys_to_groups.get(None,[])) # check against `registry`, should map keys to all groups - for key, group_obj in tomato.key_registry.items(): - if key is None: continue - - _, groups = group_obj - assert keys_to_groups.get(key) == groups + for key, group_dict in tomato.key_registry.items(): + assert keys_to_groups.get(key) == list(group_dict.keys()) def test_co3_attributes(): assert tomato.attributes is not None @@ -39,4 +36,4 @@ def test_co3_collate(): for group, keys in tomato.group_registry.items(): for key in keys: if key is None: continue - assert tomato.collate(key) is not None + assert tomato.collate(key, group=group) is not None