diff --git a/src/python/qubed/Qube.py b/src/python/qubed/Qube.py index 0ff9966..6e5547d 100644 --- a/src/python/qubed/Qube.py +++ b/src/python/qubed/Qube.py @@ -291,8 +291,6 @@ class Qube: return for index, value in enumerate(self.values): - # print(self.key, index, indices, value) - # print({k: np.shape(v) for k, v in self.metadata.items()}) indexed_metadata = { k: vs[indices + (index,)] for k, vs in self.metadata.items() } diff --git a/src/python/qubed/set_operations.py b/src/python/qubed/set_operations.py index f337b41..51cb85e 100644 --- a/src/python/qubed/set_operations.py +++ b/src/python/qubed/set_operations.py @@ -40,9 +40,10 @@ def QEnum_intersection( for index_a, val_A in enumerate(A.values): if val_A in B.values: - index_b = just_B.pop(val_A) + # print(f"{val_A} in both") + just_B.pop(val_A) intersection[val_A] = ( - index_b # We throw away any overlapping metadata from B + index_a # We throw away any overlapping metadata from B ) else: just_A[val_A] = index_a @@ -56,9 +57,7 @@ def QEnum_intersection( just_A_out = ValuesMetadata( values=QEnum(list(just_A.keys())), - metadata={ - k: v[..., tuple(intersection.values())] for k, v in A.metadata.items() - }, + metadata={k: v[..., tuple(just_A.values())] for k, v in A.metadata.items()}, ) just_B_out = ValuesMetadata( @@ -76,21 +75,21 @@ def node_intersection( if isinstance(A.values, QEnum) and isinstance(B.values, QEnum): return QEnum_intersection(A, B) - if isinstance(A, WildcardGroup) and isinstance(B, WildcardGroup): + if isinstance(A.values, WildcardGroup) and isinstance(B.values, WildcardGroup): return A, ValuesMetadata(WildcardGroup(), {}), B # If A is a wildcard matcher then the intersection is everything # just_A is still * # just_B is empty - if isinstance(A, WildcardGroup): + if isinstance(A.values, WildcardGroup): return A, B, ValuesMetadata(QEnum([]), {}) # The reverse if B is a wildcard - if isinstance(B, WildcardGroup): + if isinstance(B.values, WildcardGroup): return ValuesMetadata(QEnum([]), {}), A, B raise NotImplementedError( - f"Fused set operations on values types {type(A)} and {type(B)} not yet implemented" + f"Fused set operations on values types {type(A.values)} and {type(B.values)} not yet implemented" ) @@ -155,7 +154,7 @@ def _operation( values[node_b] = just_b if keep_intersection: - if intersection: + if intersection.values: new_node_a = replace( node_a, data=replace( @@ -177,7 +176,7 @@ def _operation( # Now we've removed all the intersections we can yield the just_A and just_B parts if needed if keep_just_A: for node in A: - if values[node]: + if values[node].values: yield node_type.make( key, children=node.children, @@ -186,7 +185,7 @@ def _operation( ) if keep_just_B: for node in B: - if values[node]: + if values[node].values: yield node_type.make( key, children=node.children, @@ -212,26 +211,41 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]: new_children = [] for child_set in identical_children.values(): if len(child_set) > 1: - child_list = list(child_set) - node_type = type(child_list[0]) - key = child_list[0].key + child_set = list(child_set) + example = child_set[0] + node_type = type(example) + key = child_set[0].key # Compress the children into a single node assert all(isinstance(child.data.values, QEnum) for child in child_set), ( "All children must have QEnum values" ) - node_data = NodeData( - key=str(key), - metadata=frozendict(), # Todo: Implement metadata compression - values=QEnum((v for child in child_set for v in child.data.values)), + metadata_groups = { + k: [child.metadata[k] for child in child_set] + for k in example.metadata.keys() + } + metadata: dict[str, np.ndarray] = frozendict( + { + k: np.concatenate(metadata_group, axis=-1) + for k, metadata_group in metadata_groups.items() + } ) - new_child = node_type(data=node_data, children=child_list[0].children) + + node_data = NodeData( + key=key, + metadata=metadata, + values=QEnum( + (v for child in child_set for v in child.data.values.values) + ), + ) + new_child = node_type(data=node_data, children=child_set[0].children) else: # If the group is size one just keep it new_child = child_set.pop() new_children.append(new_child) + return tuple(sorted(new_children, key=lambda n: ((n.key, n.values.min())))) diff --git a/src/python/qubed/value_types.py b/src/python/qubed/value_types.py index eb73b67..7eb1fd5 100644 --- a/src/python/qubed/value_types.py +++ b/src/python/qubed/value_types.py @@ -71,7 +71,7 @@ class QEnum(ValueGroup): values: EnumValuesType def __init__(self, obj): - object.__setattr__(self, "values", frozenset(obj)) + object.__setattr__(self, "values", tuple(sorted(obj))) def __post_init__(self): assert isinstance(self.values, tuple)