cosmetics

This commit is contained in:
Tom 2025-04-24 10:28:52 +01:00
parent e04c0dd3bc
commit fa646aee77
4 changed files with 20 additions and 31 deletions

View File

@ -11,7 +11,7 @@ from collections.abc import Callable
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import Any, Iterable, Iterator, Literal, Mapping, Sequence
from typing import Any, Iterable, Iterator, Literal, Sequence
import numpy as np
from frozendict import frozendict
@ -46,7 +46,7 @@ class Qube:
return self.data.values
@property
def metadata(self) -> Mapping[str, np.ndarray]:
def metadata(self):
return self.data.metadata
def replace(self, **kwargs) -> Qube:

View File

@ -5,7 +5,7 @@ import numpy as np
from .value_types import QEnum
if TYPE_CHECKING:
from .qube import Qube
from .Qube import Qube
def make_node(

View File

@ -1,5 +1,4 @@
from dataclasses import dataclass, field
from typing import Mapping
import numpy as np
from frozendict import frozendict
@ -11,8 +10,8 @@ from .value_types import ValueGroup
class NodeData:
key: str
values: ValueGroup
metadata: Mapping[str, np.ndarray] = field(
default_factory=frozendict, compare=False
metadata: frozendict[str, np.ndarray] = field(
default_factory=lambda: frozendict({}), compare=False
)
def summary(self) -> str:

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass, replace
from dataclasses import dataclass
from enum import Enum
# Prevent circular imports while allowing the type checker to know what Qube is
@ -155,21 +155,13 @@ def _operation(
if keep_intersection:
if intersection.values:
new_node_a = replace(
node_a,
data=replace(
node_a.data,
values=intersection.values,
metadata=intersection.metadata,
),
new_node_a = node_a.replace(
values=intersection.values,
metadata=intersection.metadata,
)
new_node_b = replace(
node_b,
data=replace(
node_b.data,
values=intersection.values,
metadata=intersection.metadata,
),
new_node_b = node_b.replace(
values=intersection.values,
metadata=intersection.metadata,
)
yield operation(new_node_a, new_node_b, operation_type, node_type)
@ -211,21 +203,21 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
new_children = []
for child_set in identical_children.values():
if len(child_set) > 1:
child_set = list(child_set)
example = child_set[0]
child_list = list(child_set)
example = child_list[0]
node_type = type(example)
key = child_set[0].key
key = child_list[0].key
# Compress the children into a single node
assert all(isinstance(child.data.values, QEnum) for child in child_set), (
assert all(isinstance(child.data.values, QEnum) for child in child_list), (
"All children must have QEnum values"
)
metadata_groups = {
k: [child.metadata[k] for child in child_set]
k: [child.metadata[k] for child in child_list]
for k in example.metadata.keys()
}
metadata: dict[str, np.ndarray] = frozendict(
metadata: frozendict[str, np.ndarray] = frozendict(
{
k: np.concatenate(metadata_group, axis=-1)
for k, metadata_group in metadata_groups.items()
@ -235,11 +227,9 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
node_data = NodeData(
key=key,
metadata=metadata,
values=QEnum(
(v for child in child_set for v in child.data.values.values)
),
values=QEnum((v for child in child_list for v in child.data.values)),
)
new_child = node_type(data=node_data, children=child_set[0].children)
new_child = node_type(data=node_data, children=child_list[0].children)
else:
# If the group is size one just keep it
new_child = child_set.pop()