From 4e777f295d72e65ad529fd2671852cbc93a87aa7 Mon Sep 17 00:00:00 2001
From: Tom <thomas.hodson@ecmwf.int>
Date: Fri, 28 Mar 2025 09:27:56 +0000
Subject: [PATCH] More work on metadata

---
 src/python/qubed/Qube.py           |  2 --
 src/python/qubed/set_operations.py | 54 +++++++++++++++++++-----------
 src/python/qubed/value_types.py    |  2 +-
 3 files changed, 35 insertions(+), 23 deletions(-)

diff --git a/src/python/qubed/Qube.py b/src/python/qubed/Qube.py
index 0ff9966..6e5547d 100644
--- a/src/python/qubed/Qube.py
+++ b/src/python/qubed/Qube.py
@@ -291,8 +291,6 @@ class Qube:
             return
 
         for index, value in enumerate(self.values):
-            # print(self.key, index, indices, value)
-            # print({k: np.shape(v) for k, v in self.metadata.items()})
             indexed_metadata = {
                 k: vs[indices + (index,)] for k, vs in self.metadata.items()
             }
diff --git a/src/python/qubed/set_operations.py b/src/python/qubed/set_operations.py
index f337b41..51cb85e 100644
--- a/src/python/qubed/set_operations.py
+++ b/src/python/qubed/set_operations.py
@@ -40,9 +40,10 @@ def QEnum_intersection(
 
     for index_a, val_A in enumerate(A.values):
         if val_A in B.values:
-            index_b = just_B.pop(val_A)
+            # print(f"{val_A} in both")
+            just_B.pop(val_A)
             intersection[val_A] = (
-                index_b  # We throw away any overlapping metadata from B
+                index_a  # We throw away any overlapping metadata from B
             )
         else:
             just_A[val_A] = index_a
@@ -56,9 +57,7 @@ def QEnum_intersection(
 
     just_A_out = ValuesMetadata(
         values=QEnum(list(just_A.keys())),
-        metadata={
-            k: v[..., tuple(intersection.values())] for k, v in A.metadata.items()
-        },
+        metadata={k: v[..., tuple(just_A.values())] for k, v in A.metadata.items()},
     )
 
     just_B_out = ValuesMetadata(
@@ -76,21 +75,21 @@ def node_intersection(
     if isinstance(A.values, QEnum) and isinstance(B.values, QEnum):
         return QEnum_intersection(A, B)
 
-    if isinstance(A, WildcardGroup) and isinstance(B, WildcardGroup):
+    if isinstance(A.values, WildcardGroup) and isinstance(B.values, WildcardGroup):
         return A, ValuesMetadata(WildcardGroup(), {}), B
 
     # If A is a wildcard matcher then the intersection is everything
     # just_A is still *
     # just_B is empty
-    if isinstance(A, WildcardGroup):
+    if isinstance(A.values, WildcardGroup):
         return A, B, ValuesMetadata(QEnum([]), {})
 
     # The reverse if B is a wildcard
-    if isinstance(B, WildcardGroup):
+    if isinstance(B.values, WildcardGroup):
         return ValuesMetadata(QEnum([]), {}), A, B
 
     raise NotImplementedError(
-        f"Fused set operations on values types {type(A)} and {type(B)} not yet implemented"
+        f"Fused set operations on values types {type(A.values)} and {type(B.values)} not yet implemented"
     )
 
 
@@ -155,7 +154,7 @@ def _operation(
             values[node_b] = just_b
 
             if keep_intersection:
-                if intersection:
+                if intersection.values:
                     new_node_a = replace(
                         node_a,
                         data=replace(
@@ -177,7 +176,7 @@ def _operation(
     # Now we've removed all the intersections we can yield the just_A and just_B parts if needed
     if keep_just_A:
         for node in A:
-            if values[node]:
+            if values[node].values:
                 yield node_type.make(
                     key,
                     children=node.children,
@@ -186,7 +185,7 @@ def _operation(
                 )
     if keep_just_B:
         for node in B:
-            if values[node]:
+            if values[node].values:
                 yield node_type.make(
                     key,
                     children=node.children,
@@ -212,26 +211,41 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
     new_children = []
     for child_set in identical_children.values():
         if len(child_set) > 1:
-            child_list = list(child_set)
-            node_type = type(child_list[0])
-            key = child_list[0].key
+            child_set = list(child_set)
+            example = child_set[0]
+            node_type = type(example)
+            key = child_set[0].key
 
             # Compress the children into a single node
             assert all(isinstance(child.data.values, QEnum) for child in child_set), (
                 "All children must have QEnum values"
             )
 
-            node_data = NodeData(
-                key=str(key),
-                metadata=frozendict(),  # Todo: Implement metadata compression
-                values=QEnum((v for child in child_set for v in child.data.values)),
+            metadata_groups = {
+                k: [child.metadata[k] for child in child_set]
+                for k in example.metadata.keys()
+            }
+            metadata: dict[str, np.ndarray] = frozendict(
+                {
+                    k: np.concatenate(metadata_group, axis=-1)
+                    for k, metadata_group in metadata_groups.items()
+                }
             )
-            new_child = node_type(data=node_data, children=child_list[0].children)
+
+            node_data = NodeData(
+                key=key,
+                metadata=metadata,
+                values=QEnum(
+                    (v for child in child_set for v in child.data.values.values)
+                ),
+            )
+            new_child = node_type(data=node_data, children=child_set[0].children)
         else:
             # If the group is size one just keep it
             new_child = child_set.pop()
 
         new_children.append(new_child)
+
     return tuple(sorted(new_children, key=lambda n: ((n.key, n.values.min()))))
 
 
diff --git a/src/python/qubed/value_types.py b/src/python/qubed/value_types.py
index eb73b67..7eb1fd5 100644
--- a/src/python/qubed/value_types.py
+++ b/src/python/qubed/value_types.py
@@ -71,7 +71,7 @@ class QEnum(ValueGroup):
     values: EnumValuesType
 
     def __init__(self, obj):
-        object.__setattr__(self, "values", frozenset(obj))
+        object.__setattr__(self, "values", tuple(sorted(obj)))
 
     def __post_init__(self):
         assert isinstance(self.values, tuple)