More work on metadata
This commit is contained in:
parent
1259ff08b6
commit
4e777f295d
@ -291,8 +291,6 @@ class Qube:
|
||||
return
|
||||
|
||||
for index, value in enumerate(self.values):
|
||||
# print(self.key, index, indices, value)
|
||||
# print({k: np.shape(v) for k, v in self.metadata.items()})
|
||||
indexed_metadata = {
|
||||
k: vs[indices + (index,)] for k, vs in self.metadata.items()
|
||||
}
|
||||
|
@ -40,9 +40,10 @@ def QEnum_intersection(
|
||||
|
||||
for index_a, val_A in enumerate(A.values):
|
||||
if val_A in B.values:
|
||||
index_b = just_B.pop(val_A)
|
||||
# print(f"{val_A} in both")
|
||||
just_B.pop(val_A)
|
||||
intersection[val_A] = (
|
||||
index_b # We throw away any overlapping metadata from B
|
||||
index_a # We throw away any overlapping metadata from B
|
||||
)
|
||||
else:
|
||||
just_A[val_A] = index_a
|
||||
@ -56,9 +57,7 @@ def QEnum_intersection(
|
||||
|
||||
just_A_out = ValuesMetadata(
|
||||
values=QEnum(list(just_A.keys())),
|
||||
metadata={
|
||||
k: v[..., tuple(intersection.values())] for k, v in A.metadata.items()
|
||||
},
|
||||
metadata={k: v[..., tuple(just_A.values())] for k, v in A.metadata.items()},
|
||||
)
|
||||
|
||||
just_B_out = ValuesMetadata(
|
||||
@ -76,21 +75,21 @@ def node_intersection(
|
||||
if isinstance(A.values, QEnum) and isinstance(B.values, QEnum):
|
||||
return QEnum_intersection(A, B)
|
||||
|
||||
if isinstance(A, WildcardGroup) and isinstance(B, WildcardGroup):
|
||||
if isinstance(A.values, WildcardGroup) and isinstance(B.values, WildcardGroup):
|
||||
return A, ValuesMetadata(WildcardGroup(), {}), B
|
||||
|
||||
# If A is a wildcard matcher then the intersection is everything
|
||||
# just_A is still *
|
||||
# just_B is empty
|
||||
if isinstance(A, WildcardGroup):
|
||||
if isinstance(A.values, WildcardGroup):
|
||||
return A, B, ValuesMetadata(QEnum([]), {})
|
||||
|
||||
# The reverse if B is a wildcard
|
||||
if isinstance(B, WildcardGroup):
|
||||
if isinstance(B.values, WildcardGroup):
|
||||
return ValuesMetadata(QEnum([]), {}), A, B
|
||||
|
||||
raise NotImplementedError(
|
||||
f"Fused set operations on values types {type(A)} and {type(B)} not yet implemented"
|
||||
f"Fused set operations on values types {type(A.values)} and {type(B.values)} not yet implemented"
|
||||
)
|
||||
|
||||
|
||||
@ -155,7 +154,7 @@ def _operation(
|
||||
values[node_b] = just_b
|
||||
|
||||
if keep_intersection:
|
||||
if intersection:
|
||||
if intersection.values:
|
||||
new_node_a = replace(
|
||||
node_a,
|
||||
data=replace(
|
||||
@ -177,7 +176,7 @@ def _operation(
|
||||
# Now we've removed all the intersections we can yield the just_A and just_B parts if needed
|
||||
if keep_just_A:
|
||||
for node in A:
|
||||
if values[node]:
|
||||
if values[node].values:
|
||||
yield node_type.make(
|
||||
key,
|
||||
children=node.children,
|
||||
@ -186,7 +185,7 @@ def _operation(
|
||||
)
|
||||
if keep_just_B:
|
||||
for node in B:
|
||||
if values[node]:
|
||||
if values[node].values:
|
||||
yield node_type.make(
|
||||
key,
|
||||
children=node.children,
|
||||
@ -212,26 +211,41 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
|
||||
new_children = []
|
||||
for child_set in identical_children.values():
|
||||
if len(child_set) > 1:
|
||||
child_list = list(child_set)
|
||||
node_type = type(child_list[0])
|
||||
key = child_list[0].key
|
||||
child_set = list(child_set)
|
||||
example = child_set[0]
|
||||
node_type = type(example)
|
||||
key = child_set[0].key
|
||||
|
||||
# Compress the children into a single node
|
||||
assert all(isinstance(child.data.values, QEnum) for child in child_set), (
|
||||
"All children must have QEnum values"
|
||||
)
|
||||
|
||||
node_data = NodeData(
|
||||
key=str(key),
|
||||
metadata=frozendict(), # Todo: Implement metadata compression
|
||||
values=QEnum((v for child in child_set for v in child.data.values)),
|
||||
metadata_groups = {
|
||||
k: [child.metadata[k] for child in child_set]
|
||||
for k in example.metadata.keys()
|
||||
}
|
||||
metadata: dict[str, np.ndarray] = frozendict(
|
||||
{
|
||||
k: np.concatenate(metadata_group, axis=-1)
|
||||
for k, metadata_group in metadata_groups.items()
|
||||
}
|
||||
)
|
||||
new_child = node_type(data=node_data, children=child_list[0].children)
|
||||
|
||||
node_data = NodeData(
|
||||
key=key,
|
||||
metadata=metadata,
|
||||
values=QEnum(
|
||||
(v for child in child_set for v in child.data.values.values)
|
||||
),
|
||||
)
|
||||
new_child = node_type(data=node_data, children=child_set[0].children)
|
||||
else:
|
||||
# If the group is size one just keep it
|
||||
new_child = child_set.pop()
|
||||
|
||||
new_children.append(new_child)
|
||||
|
||||
return tuple(sorted(new_children, key=lambda n: ((n.key, n.values.min()))))
|
||||
|
||||
|
||||
|
@ -71,7 +71,7 @@ class QEnum(ValueGroup):
|
||||
values: EnumValuesType
|
||||
|
||||
def __init__(self, obj):
|
||||
object.__setattr__(self, "values", frozenset(obj))
|
||||
object.__setattr__(self, "values", tuple(sorted(obj)))
|
||||
|
||||
def __post_init__(self):
|
||||
assert isinstance(self.values, tuple)
|
||||
|
Loading…
x
Reference in New Issue
Block a user