From 1259ff08b651f13416e68e5797b3a07681cfd063 Mon Sep 17 00:00:00 2001 From: Tom Date: Thu, 27 Mar 2025 18:29:50 +0000 Subject: [PATCH] first attempt --- src/python/qubed/set_operations.py | 101 +++++++++++++++++++++++------ 1 file changed, 82 insertions(+), 19 deletions(-) diff --git a/src/python/qubed/set_operations.py b/src/python/qubed/set_operations.py index 859b128..f337b41 100644 --- a/src/python/qubed/set_operations.py +++ b/src/python/qubed/set_operations.py @@ -1,12 +1,13 @@ from __future__ import annotations from collections import defaultdict -from dataclasses import replace +from dataclasses import dataclass, replace from enum import Enum # Prevent circular imports while allowing the type checker to know what Qube is -from typing import TYPE_CHECKING, Iterable +from typing import TYPE_CHECKING, Any, Iterable +import numpy as np from frozendict import frozendict from .node_types import NodeData @@ -23,28 +24,70 @@ class SetOperation(Enum): SYMMETRIC_DIFFERENCE = (1, 0, 1) +@dataclass(eq=True, frozen=True) +class ValuesMetadata: + values: ValueGroup + metadata: dict[str, np.ndarray] + + +def QEnum_intersection( + A: ValuesMetadata, + B: ValuesMetadata, +) -> tuple[ValuesMetadata, ValuesMetadata, ValuesMetadata]: + intersection: dict[Any, int] = {} + just_A: dict[Any, int] = {} + just_B: dict[Any, int] = {val: i for i, val in enumerate(B.values)} + + for index_a, val_A in enumerate(A.values): + if val_A in B.values: + index_b = just_B.pop(val_A) + intersection[val_A] = ( + index_b # We throw away any overlapping metadata from B + ) + else: + just_A[val_A] = index_a + + intersection_out = ValuesMetadata( + values=QEnum(list(intersection.keys())), + metadata={ + k: v[..., tuple(intersection.values())] for k, v in A.metadata.items() + }, + ) + + just_A_out = ValuesMetadata( + values=QEnum(list(just_A.keys())), + metadata={ + k: v[..., tuple(intersection.values())] for k, v in A.metadata.items() + }, + ) + + just_B_out = ValuesMetadata( + values=QEnum(list(just_B.keys())), + metadata={k: v[..., tuple(just_B.values())] for k, v in B.metadata.items()}, + ) + + return just_A_out, intersection_out, just_B_out + + def node_intersection( - A: ValueGroup, B: ValueGroup -) -> tuple[ValueGroup, ValueGroup, ValueGroup]: - if isinstance(A, QEnum) and isinstance(B, QEnum): - set_A, set_B = set(A), set(B) - intersection = set_A & set_B - just_A = set_A - intersection - just_B = set_B - intersection - return QEnum(just_A), QEnum(intersection), QEnum(just_B) + A: ValuesMetadata, + B: ValuesMetadata, +) -> tuple[ValuesMetadata, ValuesMetadata, ValuesMetadata]: + if isinstance(A.values, QEnum) and isinstance(B.values, QEnum): + return QEnum_intersection(A, B) if isinstance(A, WildcardGroup) and isinstance(B, WildcardGroup): - return A, WildcardGroup(), B + return A, ValuesMetadata(WildcardGroup(), {}), B # If A is a wildcard matcher then the intersection is everything # just_A is still * # just_B is empty if isinstance(A, WildcardGroup): - return A, B, QEnum([]) + return A, B, ValuesMetadata(QEnum([]), {}) # The reverse if B is a wildcard if isinstance(B, WildcardGroup): - return QEnum([]), A, B + return ValuesMetadata(QEnum([]), {}), A, B raise NotImplementedError( f"Fused set operations on values types {type(A)} and {type(B)} not yet implemented" @@ -96,7 +139,7 @@ def _operation( # Iterate over all pairs (node_A, node_B) values = {} for node in A + B: - values[node] = node.values + values[node] = ValuesMetadata(node.values, node.metadata) for node_a in A: for node_b in B: @@ -114,10 +157,20 @@ def _operation( if keep_intersection: if intersection: new_node_a = replace( - node_a, data=replace(node_a.data, values=intersection) + node_a, + data=replace( + node_a.data, + values=intersection.values, + metadata=intersection.metadata, + ), ) new_node_b = replace( - node_b, data=replace(node_b.data, values=intersection) + node_b, + data=replace( + node_b.data, + values=intersection.values, + metadata=intersection.metadata, + ), ) yield operation(new_node_a, new_node_b, operation_type, node_type) @@ -125,11 +178,21 @@ def _operation( if keep_just_A: for node in A: if values[node]: - yield node_type.make(key, values[node], node.children) + yield node_type.make( + key, + children=node.children, + values=values[node].values, + metadata=values[node].metadata, + ) if keep_just_B: for node in B: if values[node]: - yield node_type.make(key, values[node], node.children) + yield node_type.make( + key, + children=node.children, + values=values[node].values, + metadata=values[node].metadata, + ) def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]: @@ -137,7 +200,7 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]: Helper method tht only compresses a set of nodes, and doesn't do it recursively. Used in Qubed.compress but also to maintain compression in the set operations above. """ - # Now take the set of new children and see if any have identical key, metadata and children + # Take the set of new children and see if any have identical key, metadata and children # the values may different and will be collapsed into a single node identical_children = defaultdict(set) for child in children: