More work on metadata
This commit is contained in:
parent
1259ff08b6
commit
4e777f295d
@ -291,8 +291,6 @@ class Qube:
|
|||||||
return
|
return
|
||||||
|
|
||||||
for index, value in enumerate(self.values):
|
for index, value in enumerate(self.values):
|
||||||
# print(self.key, index, indices, value)
|
|
||||||
# print({k: np.shape(v) for k, v in self.metadata.items()})
|
|
||||||
indexed_metadata = {
|
indexed_metadata = {
|
||||||
k: vs[indices + (index,)] for k, vs in self.metadata.items()
|
k: vs[indices + (index,)] for k, vs in self.metadata.items()
|
||||||
}
|
}
|
||||||
|
@ -40,9 +40,10 @@ def QEnum_intersection(
|
|||||||
|
|
||||||
for index_a, val_A in enumerate(A.values):
|
for index_a, val_A in enumerate(A.values):
|
||||||
if val_A in B.values:
|
if val_A in B.values:
|
||||||
index_b = just_B.pop(val_A)
|
# print(f"{val_A} in both")
|
||||||
|
just_B.pop(val_A)
|
||||||
intersection[val_A] = (
|
intersection[val_A] = (
|
||||||
index_b # We throw away any overlapping metadata from B
|
index_a # We throw away any overlapping metadata from B
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
just_A[val_A] = index_a
|
just_A[val_A] = index_a
|
||||||
@ -56,9 +57,7 @@ def QEnum_intersection(
|
|||||||
|
|
||||||
just_A_out = ValuesMetadata(
|
just_A_out = ValuesMetadata(
|
||||||
values=QEnum(list(just_A.keys())),
|
values=QEnum(list(just_A.keys())),
|
||||||
metadata={
|
metadata={k: v[..., tuple(just_A.values())] for k, v in A.metadata.items()},
|
||||||
k: v[..., tuple(intersection.values())] for k, v in A.metadata.items()
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
just_B_out = ValuesMetadata(
|
just_B_out = ValuesMetadata(
|
||||||
@ -76,21 +75,21 @@ def node_intersection(
|
|||||||
if isinstance(A.values, QEnum) and isinstance(B.values, QEnum):
|
if isinstance(A.values, QEnum) and isinstance(B.values, QEnum):
|
||||||
return QEnum_intersection(A, B)
|
return QEnum_intersection(A, B)
|
||||||
|
|
||||||
if isinstance(A, WildcardGroup) and isinstance(B, WildcardGroup):
|
if isinstance(A.values, WildcardGroup) and isinstance(B.values, WildcardGroup):
|
||||||
return A, ValuesMetadata(WildcardGroup(), {}), B
|
return A, ValuesMetadata(WildcardGroup(), {}), B
|
||||||
|
|
||||||
# If A is a wildcard matcher then the intersection is everything
|
# If A is a wildcard matcher then the intersection is everything
|
||||||
# just_A is still *
|
# just_A is still *
|
||||||
# just_B is empty
|
# just_B is empty
|
||||||
if isinstance(A, WildcardGroup):
|
if isinstance(A.values, WildcardGroup):
|
||||||
return A, B, ValuesMetadata(QEnum([]), {})
|
return A, B, ValuesMetadata(QEnum([]), {})
|
||||||
|
|
||||||
# The reverse if B is a wildcard
|
# The reverse if B is a wildcard
|
||||||
if isinstance(B, WildcardGroup):
|
if isinstance(B.values, WildcardGroup):
|
||||||
return ValuesMetadata(QEnum([]), {}), A, B
|
return ValuesMetadata(QEnum([]), {}), A, B
|
||||||
|
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Fused set operations on values types {type(A)} and {type(B)} not yet implemented"
|
f"Fused set operations on values types {type(A.values)} and {type(B.values)} not yet implemented"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -155,7 +154,7 @@ def _operation(
|
|||||||
values[node_b] = just_b
|
values[node_b] = just_b
|
||||||
|
|
||||||
if keep_intersection:
|
if keep_intersection:
|
||||||
if intersection:
|
if intersection.values:
|
||||||
new_node_a = replace(
|
new_node_a = replace(
|
||||||
node_a,
|
node_a,
|
||||||
data=replace(
|
data=replace(
|
||||||
@ -177,7 +176,7 @@ def _operation(
|
|||||||
# Now we've removed all the intersections we can yield the just_A and just_B parts if needed
|
# Now we've removed all the intersections we can yield the just_A and just_B parts if needed
|
||||||
if keep_just_A:
|
if keep_just_A:
|
||||||
for node in A:
|
for node in A:
|
||||||
if values[node]:
|
if values[node].values:
|
||||||
yield node_type.make(
|
yield node_type.make(
|
||||||
key,
|
key,
|
||||||
children=node.children,
|
children=node.children,
|
||||||
@ -186,7 +185,7 @@ def _operation(
|
|||||||
)
|
)
|
||||||
if keep_just_B:
|
if keep_just_B:
|
||||||
for node in B:
|
for node in B:
|
||||||
if values[node]:
|
if values[node].values:
|
||||||
yield node_type.make(
|
yield node_type.make(
|
||||||
key,
|
key,
|
||||||
children=node.children,
|
children=node.children,
|
||||||
@ -212,26 +211,41 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
|
|||||||
new_children = []
|
new_children = []
|
||||||
for child_set in identical_children.values():
|
for child_set in identical_children.values():
|
||||||
if len(child_set) > 1:
|
if len(child_set) > 1:
|
||||||
child_list = list(child_set)
|
child_set = list(child_set)
|
||||||
node_type = type(child_list[0])
|
example = child_set[0]
|
||||||
key = child_list[0].key
|
node_type = type(example)
|
||||||
|
key = child_set[0].key
|
||||||
|
|
||||||
# Compress the children into a single node
|
# Compress the children into a single node
|
||||||
assert all(isinstance(child.data.values, QEnum) for child in child_set), (
|
assert all(isinstance(child.data.values, QEnum) for child in child_set), (
|
||||||
"All children must have QEnum values"
|
"All children must have QEnum values"
|
||||||
)
|
)
|
||||||
|
|
||||||
node_data = NodeData(
|
metadata_groups = {
|
||||||
key=str(key),
|
k: [child.metadata[k] for child in child_set]
|
||||||
metadata=frozendict(), # Todo: Implement metadata compression
|
for k in example.metadata.keys()
|
||||||
values=QEnum((v for child in child_set for v in child.data.values)),
|
}
|
||||||
|
metadata: dict[str, np.ndarray] = frozendict(
|
||||||
|
{
|
||||||
|
k: np.concatenate(metadata_group, axis=-1)
|
||||||
|
for k, metadata_group in metadata_groups.items()
|
||||||
|
}
|
||||||
)
|
)
|
||||||
new_child = node_type(data=node_data, children=child_list[0].children)
|
|
||||||
|
node_data = NodeData(
|
||||||
|
key=key,
|
||||||
|
metadata=metadata,
|
||||||
|
values=QEnum(
|
||||||
|
(v for child in child_set for v in child.data.values.values)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
new_child = node_type(data=node_data, children=child_set[0].children)
|
||||||
else:
|
else:
|
||||||
# If the group is size one just keep it
|
# If the group is size one just keep it
|
||||||
new_child = child_set.pop()
|
new_child = child_set.pop()
|
||||||
|
|
||||||
new_children.append(new_child)
|
new_children.append(new_child)
|
||||||
|
|
||||||
return tuple(sorted(new_children, key=lambda n: ((n.key, n.values.min()))))
|
return tuple(sorted(new_children, key=lambda n: ((n.key, n.values.min()))))
|
||||||
|
|
||||||
|
|
||||||
|
@ -71,7 +71,7 @@ class QEnum(ValueGroup):
|
|||||||
values: EnumValuesType
|
values: EnumValuesType
|
||||||
|
|
||||||
def __init__(self, obj):
|
def __init__(self, obj):
|
||||||
object.__setattr__(self, "values", frozenset(obj))
|
object.__setattr__(self, "values", tuple(sorted(obj)))
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
assert isinstance(self.values, tuple)
|
assert isinstance(self.values, tuple)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user