first attempt

This commit is contained in:
Tom 2025-03-27 18:29:50 +00:00
parent 7b36a76154
commit 1259ff08b6

View File

@ -1,12 +1,13 @@
from __future__ import annotations from __future__ import annotations
from collections import defaultdict from collections import defaultdict
from dataclasses import replace from dataclasses import dataclass, replace
from enum import Enum from enum import Enum
# Prevent circular imports while allowing the type checker to know what Qube is # Prevent circular imports while allowing the type checker to know what Qube is
from typing import TYPE_CHECKING, Iterable from typing import TYPE_CHECKING, Any, Iterable
import numpy as np
from frozendict import frozendict from frozendict import frozendict
from .node_types import NodeData from .node_types import NodeData
@ -23,28 +24,70 @@ class SetOperation(Enum):
SYMMETRIC_DIFFERENCE = (1, 0, 1) SYMMETRIC_DIFFERENCE = (1, 0, 1)
@dataclass(eq=True, frozen=True)
class ValuesMetadata:
values: ValueGroup
metadata: dict[str, np.ndarray]
def QEnum_intersection(
A: ValuesMetadata,
B: ValuesMetadata,
) -> tuple[ValuesMetadata, ValuesMetadata, ValuesMetadata]:
intersection: dict[Any, int] = {}
just_A: dict[Any, int] = {}
just_B: dict[Any, int] = {val: i for i, val in enumerate(B.values)}
for index_a, val_A in enumerate(A.values):
if val_A in B.values:
index_b = just_B.pop(val_A)
intersection[val_A] = (
index_b # We throw away any overlapping metadata from B
)
else:
just_A[val_A] = index_a
intersection_out = ValuesMetadata(
values=QEnum(list(intersection.keys())),
metadata={
k: v[..., tuple(intersection.values())] for k, v in A.metadata.items()
},
)
just_A_out = ValuesMetadata(
values=QEnum(list(just_A.keys())),
metadata={
k: v[..., tuple(intersection.values())] for k, v in A.metadata.items()
},
)
just_B_out = ValuesMetadata(
values=QEnum(list(just_B.keys())),
metadata={k: v[..., tuple(just_B.values())] for k, v in B.metadata.items()},
)
return just_A_out, intersection_out, just_B_out
def node_intersection( def node_intersection(
A: ValueGroup, B: ValueGroup A: ValuesMetadata,
) -> tuple[ValueGroup, ValueGroup, ValueGroup]: B: ValuesMetadata,
if isinstance(A, QEnum) and isinstance(B, QEnum): ) -> tuple[ValuesMetadata, ValuesMetadata, ValuesMetadata]:
set_A, set_B = set(A), set(B) if isinstance(A.values, QEnum) and isinstance(B.values, QEnum):
intersection = set_A & set_B return QEnum_intersection(A, B)
just_A = set_A - intersection
just_B = set_B - intersection
return QEnum(just_A), QEnum(intersection), QEnum(just_B)
if isinstance(A, WildcardGroup) and isinstance(B, WildcardGroup): if isinstance(A, WildcardGroup) and isinstance(B, WildcardGroup):
return A, WildcardGroup(), B return A, ValuesMetadata(WildcardGroup(), {}), B
# If A is a wildcard matcher then the intersection is everything # If A is a wildcard matcher then the intersection is everything
# just_A is still * # just_A is still *
# just_B is empty # just_B is empty
if isinstance(A, WildcardGroup): if isinstance(A, WildcardGroup):
return A, B, QEnum([]) return A, B, ValuesMetadata(QEnum([]), {})
# The reverse if B is a wildcard # The reverse if B is a wildcard
if isinstance(B, WildcardGroup): if isinstance(B, WildcardGroup):
return QEnum([]), A, B return ValuesMetadata(QEnum([]), {}), A, B
raise NotImplementedError( raise NotImplementedError(
f"Fused set operations on values types {type(A)} and {type(B)} not yet implemented" f"Fused set operations on values types {type(A)} and {type(B)} not yet implemented"
@ -96,7 +139,7 @@ def _operation(
# Iterate over all pairs (node_A, node_B) # Iterate over all pairs (node_A, node_B)
values = {} values = {}
for node in A + B: for node in A + B:
values[node] = node.values values[node] = ValuesMetadata(node.values, node.metadata)
for node_a in A: for node_a in A:
for node_b in B: for node_b in B:
@ -114,10 +157,20 @@ def _operation(
if keep_intersection: if keep_intersection:
if intersection: if intersection:
new_node_a = replace( new_node_a = replace(
node_a, data=replace(node_a.data, values=intersection) node_a,
data=replace(
node_a.data,
values=intersection.values,
metadata=intersection.metadata,
),
) )
new_node_b = replace( new_node_b = replace(
node_b, data=replace(node_b.data, values=intersection) node_b,
data=replace(
node_b.data,
values=intersection.values,
metadata=intersection.metadata,
),
) )
yield operation(new_node_a, new_node_b, operation_type, node_type) yield operation(new_node_a, new_node_b, operation_type, node_type)
@ -125,11 +178,21 @@ def _operation(
if keep_just_A: if keep_just_A:
for node in A: for node in A:
if values[node]: if values[node]:
yield node_type.make(key, values[node], node.children) yield node_type.make(
key,
children=node.children,
values=values[node].values,
metadata=values[node].metadata,
)
if keep_just_B: if keep_just_B:
for node in B: for node in B:
if values[node]: if values[node]:
yield node_type.make(key, values[node], node.children) yield node_type.make(
key,
children=node.children,
values=values[node].values,
metadata=values[node].metadata,
)
def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]: def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
@ -137,7 +200,7 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
Helper method tht only compresses a set of nodes, and doesn't do it recursively. Helper method tht only compresses a set of nodes, and doesn't do it recursively.
Used in Qubed.compress but also to maintain compression in the set operations above. Used in Qubed.compress but also to maintain compression in the set operations above.
""" """
# Now take the set of new children and see if any have identical key, metadata and children # Take the set of new children and see if any have identical key, metadata and children
# the values may different and will be collapsed into a single node # the values may different and will be collapsed into a single node
identical_children = defaultdict(set) identical_children = defaultdict(set)
for child in children: for child in children: