first attempt

This commit is contained in:
Tom 2025-03-27 18:29:50 +00:00
parent 7b36a76154
commit 1259ff08b6

View File

@ -1,12 +1,13 @@
from __future__ import annotations
from collections import defaultdict
from dataclasses import replace
from dataclasses import dataclass, replace
from enum import Enum
# Prevent circular imports while allowing the type checker to know what Qube is
from typing import TYPE_CHECKING, Iterable
from typing import TYPE_CHECKING, Any, Iterable
import numpy as np
from frozendict import frozendict
from .node_types import NodeData
@ -23,28 +24,70 @@ class SetOperation(Enum):
SYMMETRIC_DIFFERENCE = (1, 0, 1)
@dataclass(eq=True, frozen=True)
class ValuesMetadata:
values: ValueGroup
metadata: dict[str, np.ndarray]
def QEnum_intersection(
A: ValuesMetadata,
B: ValuesMetadata,
) -> tuple[ValuesMetadata, ValuesMetadata, ValuesMetadata]:
intersection: dict[Any, int] = {}
just_A: dict[Any, int] = {}
just_B: dict[Any, int] = {val: i for i, val in enumerate(B.values)}
for index_a, val_A in enumerate(A.values):
if val_A in B.values:
index_b = just_B.pop(val_A)
intersection[val_A] = (
index_b # We throw away any overlapping metadata from B
)
else:
just_A[val_A] = index_a
intersection_out = ValuesMetadata(
values=QEnum(list(intersection.keys())),
metadata={
k: v[..., tuple(intersection.values())] for k, v in A.metadata.items()
},
)
just_A_out = ValuesMetadata(
values=QEnum(list(just_A.keys())),
metadata={
k: v[..., tuple(intersection.values())] for k, v in A.metadata.items()
},
)
just_B_out = ValuesMetadata(
values=QEnum(list(just_B.keys())),
metadata={k: v[..., tuple(just_B.values())] for k, v in B.metadata.items()},
)
return just_A_out, intersection_out, just_B_out
def node_intersection(
A: ValueGroup, B: ValueGroup
) -> tuple[ValueGroup, ValueGroup, ValueGroup]:
if isinstance(A, QEnum) and isinstance(B, QEnum):
set_A, set_B = set(A), set(B)
intersection = set_A & set_B
just_A = set_A - intersection
just_B = set_B - intersection
return QEnum(just_A), QEnum(intersection), QEnum(just_B)
A: ValuesMetadata,
B: ValuesMetadata,
) -> tuple[ValuesMetadata, ValuesMetadata, ValuesMetadata]:
if isinstance(A.values, QEnum) and isinstance(B.values, QEnum):
return QEnum_intersection(A, B)
if isinstance(A, WildcardGroup) and isinstance(B, WildcardGroup):
return A, WildcardGroup(), B
return A, ValuesMetadata(WildcardGroup(), {}), B
# If A is a wildcard matcher then the intersection is everything
# just_A is still *
# just_B is empty
if isinstance(A, WildcardGroup):
return A, B, QEnum([])
return A, B, ValuesMetadata(QEnum([]), {})
# The reverse if B is a wildcard
if isinstance(B, WildcardGroup):
return QEnum([]), A, B
return ValuesMetadata(QEnum([]), {}), A, B
raise NotImplementedError(
f"Fused set operations on values types {type(A)} and {type(B)} not yet implemented"
@ -96,7 +139,7 @@ def _operation(
# Iterate over all pairs (node_A, node_B)
values = {}
for node in A + B:
values[node] = node.values
values[node] = ValuesMetadata(node.values, node.metadata)
for node_a in A:
for node_b in B:
@ -114,10 +157,20 @@ def _operation(
if keep_intersection:
if intersection:
new_node_a = replace(
node_a, data=replace(node_a.data, values=intersection)
node_a,
data=replace(
node_a.data,
values=intersection.values,
metadata=intersection.metadata,
),
)
new_node_b = replace(
node_b, data=replace(node_b.data, values=intersection)
node_b,
data=replace(
node_b.data,
values=intersection.values,
metadata=intersection.metadata,
),
)
yield operation(new_node_a, new_node_b, operation_type, node_type)
@ -125,11 +178,21 @@ def _operation(
if keep_just_A:
for node in A:
if values[node]:
yield node_type.make(key, values[node], node.children)
yield node_type.make(
key,
children=node.children,
values=values[node].values,
metadata=values[node].metadata,
)
if keep_just_B:
for node in B:
if values[node]:
yield node_type.make(key, values[node], node.children)
yield node_type.make(
key,
children=node.children,
values=values[node].values,
metadata=values[node].metadata,
)
def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
@ -137,7 +200,7 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
Helper method tht only compresses a set of nodes, and doesn't do it recursively.
Used in Qubed.compress but also to maintain compression in the set operations above.
"""
# Now take the set of new children and see if any have identical key, metadata and children
# Take the set of new children and see if any have identical key, metadata and children
# the values may different and will be collapsed into a single node
identical_children = defaultdict(set)
for child in children: