cosmetics

This commit is contained in:
Tom 2025-04-24 10:28:52 +01:00
parent e04c0dd3bc
commit fa646aee77
4 changed files with 20 additions and 31 deletions

View File

@ -11,7 +11,7 @@ from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from functools import cached_property from functools import cached_property
from pathlib import Path from pathlib import Path
from typing import Any, Iterable, Iterator, Literal, Mapping, Sequence from typing import Any, Iterable, Iterator, Literal, Sequence
import numpy as np import numpy as np
from frozendict import frozendict from frozendict import frozendict
@ -46,7 +46,7 @@ class Qube:
return self.data.values return self.data.values
@property @property
def metadata(self) -> Mapping[str, np.ndarray]: def metadata(self):
return self.data.metadata return self.data.metadata
def replace(self, **kwargs) -> Qube: def replace(self, **kwargs) -> Qube:

View File

@ -5,7 +5,7 @@ import numpy as np
from .value_types import QEnum from .value_types import QEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from .qube import Qube from .Qube import Qube
def make_node( def make_node(

View File

@ -1,5 +1,4 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Mapping
import numpy as np import numpy as np
from frozendict import frozendict from frozendict import frozendict
@ -11,8 +10,8 @@ from .value_types import ValueGroup
class NodeData: class NodeData:
key: str key: str
values: ValueGroup values: ValueGroup
metadata: Mapping[str, np.ndarray] = field( metadata: frozendict[str, np.ndarray] = field(
default_factory=frozendict, compare=False default_factory=lambda: frozendict({}), compare=False
) )
def summary(self) -> str: def summary(self) -> str:

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, replace from dataclasses import dataclass
from enum import Enum from enum import Enum
# Prevent circular imports while allowing the type checker to know what Qube is # Prevent circular imports while allowing the type checker to know what Qube is
@ -155,21 +155,13 @@ def _operation(
if keep_intersection: if keep_intersection:
if intersection.values: if intersection.values:
new_node_a = replace( new_node_a = node_a.replace(
node_a, values=intersection.values,
data=replace( metadata=intersection.metadata,
node_a.data,
values=intersection.values,
metadata=intersection.metadata,
),
) )
new_node_b = replace( new_node_b = node_b.replace(
node_b, values=intersection.values,
data=replace( metadata=intersection.metadata,
node_b.data,
values=intersection.values,
metadata=intersection.metadata,
),
) )
yield operation(new_node_a, new_node_b, operation_type, node_type) yield operation(new_node_a, new_node_b, operation_type, node_type)
@ -211,21 +203,21 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
new_children = [] new_children = []
for child_set in identical_children.values(): for child_set in identical_children.values():
if len(child_set) > 1: if len(child_set) > 1:
child_set = list(child_set) child_list = list(child_set)
example = child_set[0] example = child_list[0]
node_type = type(example) node_type = type(example)
key = child_set[0].key key = child_list[0].key
# Compress the children into a single node # Compress the children into a single node
assert all(isinstance(child.data.values, QEnum) for child in child_set), ( assert all(isinstance(child.data.values, QEnum) for child in child_list), (
"All children must have QEnum values" "All children must have QEnum values"
) )
metadata_groups = { metadata_groups = {
k: [child.metadata[k] for child in child_set] k: [child.metadata[k] for child in child_list]
for k in example.metadata.keys() for k in example.metadata.keys()
} }
metadata: dict[str, np.ndarray] = frozendict( metadata: frozendict[str, np.ndarray] = frozendict(
{ {
k: np.concatenate(metadata_group, axis=-1) k: np.concatenate(metadata_group, axis=-1)
for k, metadata_group in metadata_groups.items() for k, metadata_group in metadata_groups.items()
@ -235,11 +227,9 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
node_data = NodeData( node_data = NodeData(
key=key, key=key,
metadata=metadata, metadata=metadata,
values=QEnum( values=QEnum((v for child in child_list for v in child.data.values)),
(v for child in child_set for v in child.data.values.values)
),
) )
new_child = node_type(data=node_data, children=child_set[0].children) new_child = node_type(data=node_data, children=child_list[0].children)
else: else:
# If the group is size one just keep it # If the group is size one just keep it
new_child = child_set.pop() new_child = child_set.pop()