first attempt
This commit is contained in:
parent
7b36a76154
commit
1259ff08b6
@ -1,12 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import replace
|
||||
from dataclasses import dataclass, replace
|
||||
from enum import Enum
|
||||
|
||||
# Prevent circular imports while allowing the type checker to know what Qube is
|
||||
from typing import TYPE_CHECKING, Iterable
|
||||
from typing import TYPE_CHECKING, Any, Iterable
|
||||
|
||||
import numpy as np
|
||||
from frozendict import frozendict
|
||||
|
||||
from .node_types import NodeData
|
||||
@ -23,28 +24,70 @@ class SetOperation(Enum):
|
||||
SYMMETRIC_DIFFERENCE = (1, 0, 1)
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
class ValuesMetadata:
|
||||
values: ValueGroup
|
||||
metadata: dict[str, np.ndarray]
|
||||
|
||||
|
||||
def QEnum_intersection(
|
||||
A: ValuesMetadata,
|
||||
B: ValuesMetadata,
|
||||
) -> tuple[ValuesMetadata, ValuesMetadata, ValuesMetadata]:
|
||||
intersection: dict[Any, int] = {}
|
||||
just_A: dict[Any, int] = {}
|
||||
just_B: dict[Any, int] = {val: i for i, val in enumerate(B.values)}
|
||||
|
||||
for index_a, val_A in enumerate(A.values):
|
||||
if val_A in B.values:
|
||||
index_b = just_B.pop(val_A)
|
||||
intersection[val_A] = (
|
||||
index_b # We throw away any overlapping metadata from B
|
||||
)
|
||||
else:
|
||||
just_A[val_A] = index_a
|
||||
|
||||
intersection_out = ValuesMetadata(
|
||||
values=QEnum(list(intersection.keys())),
|
||||
metadata={
|
||||
k: v[..., tuple(intersection.values())] for k, v in A.metadata.items()
|
||||
},
|
||||
)
|
||||
|
||||
just_A_out = ValuesMetadata(
|
||||
values=QEnum(list(just_A.keys())),
|
||||
metadata={
|
||||
k: v[..., tuple(intersection.values())] for k, v in A.metadata.items()
|
||||
},
|
||||
)
|
||||
|
||||
just_B_out = ValuesMetadata(
|
||||
values=QEnum(list(just_B.keys())),
|
||||
metadata={k: v[..., tuple(just_B.values())] for k, v in B.metadata.items()},
|
||||
)
|
||||
|
||||
return just_A_out, intersection_out, just_B_out
|
||||
|
||||
|
||||
def node_intersection(
|
||||
A: ValueGroup, B: ValueGroup
|
||||
) -> tuple[ValueGroup, ValueGroup, ValueGroup]:
|
||||
if isinstance(A, QEnum) and isinstance(B, QEnum):
|
||||
set_A, set_B = set(A), set(B)
|
||||
intersection = set_A & set_B
|
||||
just_A = set_A - intersection
|
||||
just_B = set_B - intersection
|
||||
return QEnum(just_A), QEnum(intersection), QEnum(just_B)
|
||||
A: ValuesMetadata,
|
||||
B: ValuesMetadata,
|
||||
) -> tuple[ValuesMetadata, ValuesMetadata, ValuesMetadata]:
|
||||
if isinstance(A.values, QEnum) and isinstance(B.values, QEnum):
|
||||
return QEnum_intersection(A, B)
|
||||
|
||||
if isinstance(A, WildcardGroup) and isinstance(B, WildcardGroup):
|
||||
return A, WildcardGroup(), B
|
||||
return A, ValuesMetadata(WildcardGroup(), {}), B
|
||||
|
||||
# If A is a wildcard matcher then the intersection is everything
|
||||
# just_A is still *
|
||||
# just_B is empty
|
||||
if isinstance(A, WildcardGroup):
|
||||
return A, B, QEnum([])
|
||||
return A, B, ValuesMetadata(QEnum([]), {})
|
||||
|
||||
# The reverse if B is a wildcard
|
||||
if isinstance(B, WildcardGroup):
|
||||
return QEnum([]), A, B
|
||||
return ValuesMetadata(QEnum([]), {}), A, B
|
||||
|
||||
raise NotImplementedError(
|
||||
f"Fused set operations on values types {type(A)} and {type(B)} not yet implemented"
|
||||
@ -96,7 +139,7 @@ def _operation(
|
||||
# Iterate over all pairs (node_A, node_B)
|
||||
values = {}
|
||||
for node in A + B:
|
||||
values[node] = node.values
|
||||
values[node] = ValuesMetadata(node.values, node.metadata)
|
||||
|
||||
for node_a in A:
|
||||
for node_b in B:
|
||||
@ -114,10 +157,20 @@ def _operation(
|
||||
if keep_intersection:
|
||||
if intersection:
|
||||
new_node_a = replace(
|
||||
node_a, data=replace(node_a.data, values=intersection)
|
||||
node_a,
|
||||
data=replace(
|
||||
node_a.data,
|
||||
values=intersection.values,
|
||||
metadata=intersection.metadata,
|
||||
),
|
||||
)
|
||||
new_node_b = replace(
|
||||
node_b, data=replace(node_b.data, values=intersection)
|
||||
node_b,
|
||||
data=replace(
|
||||
node_b.data,
|
||||
values=intersection.values,
|
||||
metadata=intersection.metadata,
|
||||
),
|
||||
)
|
||||
yield operation(new_node_a, new_node_b, operation_type, node_type)
|
||||
|
||||
@ -125,11 +178,21 @@ def _operation(
|
||||
if keep_just_A:
|
||||
for node in A:
|
||||
if values[node]:
|
||||
yield node_type.make(key, values[node], node.children)
|
||||
yield node_type.make(
|
||||
key,
|
||||
children=node.children,
|
||||
values=values[node].values,
|
||||
metadata=values[node].metadata,
|
||||
)
|
||||
if keep_just_B:
|
||||
for node in B:
|
||||
if values[node]:
|
||||
yield node_type.make(key, values[node], node.children)
|
||||
yield node_type.make(
|
||||
key,
|
||||
children=node.children,
|
||||
values=values[node].values,
|
||||
metadata=values[node].metadata,
|
||||
)
|
||||
|
||||
|
||||
def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
|
||||
@ -137,7 +200,7 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
|
||||
Helper method tht only compresses a set of nodes, and doesn't do it recursively.
|
||||
Used in Qubed.compress but also to maintain compression in the set operations above.
|
||||
"""
|
||||
# Now take the set of new children and see if any have identical key, metadata and children
|
||||
# Take the set of new children and see if any have identical key, metadata and children
|
||||
# the values may different and will be collapsed into a single node
|
||||
identical_children = defaultdict(set)
|
||||
for child in children:
|
||||
|
Loading…
x
Reference in New Issue
Block a user