More work on metadata

This commit is contained in:
Tom 2025-03-28 09:27:56 +00:00
parent 1259ff08b6
commit 4e777f295d
3 changed files with 35 additions and 23 deletions

View File

@ -291,8 +291,6 @@ class Qube:
return
for index, value in enumerate(self.values):
# print(self.key, index, indices, value)
# print({k: np.shape(v) for k, v in self.metadata.items()})
indexed_metadata = {
k: vs[indices + (index,)] for k, vs in self.metadata.items()
}

View File

@ -40,9 +40,10 @@ def QEnum_intersection(
for index_a, val_A in enumerate(A.values):
if val_A in B.values:
index_b = just_B.pop(val_A)
# print(f"{val_A} in both")
just_B.pop(val_A)
intersection[val_A] = (
index_b # We throw away any overlapping metadata from B
index_a # We throw away any overlapping metadata from B
)
else:
just_A[val_A] = index_a
@ -56,9 +57,7 @@ def QEnum_intersection(
just_A_out = ValuesMetadata(
values=QEnum(list(just_A.keys())),
metadata={
k: v[..., tuple(intersection.values())] for k, v in A.metadata.items()
},
metadata={k: v[..., tuple(just_A.values())] for k, v in A.metadata.items()},
)
just_B_out = ValuesMetadata(
@ -76,21 +75,21 @@ def node_intersection(
if isinstance(A.values, QEnum) and isinstance(B.values, QEnum):
return QEnum_intersection(A, B)
if isinstance(A, WildcardGroup) and isinstance(B, WildcardGroup):
if isinstance(A.values, WildcardGroup) and isinstance(B.values, WildcardGroup):
return A, ValuesMetadata(WildcardGroup(), {}), B
# If A is a wildcard matcher then the intersection is everything
# just_A is still *
# just_B is empty
if isinstance(A, WildcardGroup):
if isinstance(A.values, WildcardGroup):
return A, B, ValuesMetadata(QEnum([]), {})
# The reverse if B is a wildcard
if isinstance(B, WildcardGroup):
if isinstance(B.values, WildcardGroup):
return ValuesMetadata(QEnum([]), {}), A, B
raise NotImplementedError(
f"Fused set operations on values types {type(A)} and {type(B)} not yet implemented"
f"Fused set operations on values types {type(A.values)} and {type(B.values)} not yet implemented"
)
@ -155,7 +154,7 @@ def _operation(
values[node_b] = just_b
if keep_intersection:
if intersection:
if intersection.values:
new_node_a = replace(
node_a,
data=replace(
@ -177,7 +176,7 @@ def _operation(
# Now that we've removed all the intersections, we can yield the just_A and just_B parts if needed
if keep_just_A:
for node in A:
if values[node]:
if values[node].values:
yield node_type.make(
key,
children=node.children,
@ -186,7 +185,7 @@ def _operation(
)
if keep_just_B:
for node in B:
if values[node]:
if values[node].values:
yield node_type.make(
key,
children=node.children,
@ -212,26 +211,41 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
new_children = []
for child_set in identical_children.values():
if len(child_set) > 1:
child_list = list(child_set)
node_type = type(child_list[0])
key = child_list[0].key
child_set = list(child_set)
example = child_set[0]
node_type = type(example)
key = child_set[0].key
# Compress the children into a single node
assert all(isinstance(child.data.values, QEnum) for child in child_set), (
"All children must have QEnum values"
)
node_data = NodeData(
key=str(key),
metadata=frozendict(), # Todo: Implement metadata compression
values=QEnum((v for child in child_set for v in child.data.values)),
metadata_groups = {
k: [child.metadata[k] for child in child_set]
for k in example.metadata.keys()
}
metadata: dict[str, np.ndarray] = frozendict(
{
k: np.concatenate(metadata_group, axis=-1)
for k, metadata_group in metadata_groups.items()
}
)
new_child = node_type(data=node_data, children=child_list[0].children)
node_data = NodeData(
key=key,
metadata=metadata,
values=QEnum(
(v for child in child_set for v in child.data.values.values)
),
)
new_child = node_type(data=node_data, children=child_set[0].children)
else:
# If the group is size one just keep it
new_child = child_set.pop()
new_children.append(new_child)
return tuple(sorted(new_children, key=lambda n: ((n.key, n.values.min()))))

View File

@ -71,7 +71,7 @@ class QEnum(ValueGroup):
values: EnumValuesType
def __init__(self, obj):
object.__setattr__(self, "values", frozenset(obj))
object.__setattr__(self, "values", tuple(sorted(obj)))
def __post_init__(self):
assert isinstance(self.values, tuple)