cosmetics
This commit is contained in:
parent
e04c0dd3bc
commit
fa646aee77
@ -11,7 +11,7 @@ from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, Iterator, Literal, Mapping, Sequence
|
||||
from typing import Any, Iterable, Iterator, Literal, Sequence
|
||||
|
||||
import numpy as np
|
||||
from frozendict import frozendict
|
||||
@ -46,7 +46,7 @@ class Qube:
|
||||
return self.data.values
|
||||
|
||||
@property
|
||||
def metadata(self) -> Mapping[str, np.ndarray]:
|
||||
def metadata(self):
|
||||
return self.data.metadata
|
||||
|
||||
def replace(self, **kwargs) -> Qube:
|
||||
|
@ -5,7 +5,7 @@ import numpy as np
|
||||
from .value_types import QEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .qube import Qube
|
||||
from .Qube import Qube
|
||||
|
||||
|
||||
def make_node(
|
||||
|
@ -1,5 +1,4 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Mapping
|
||||
|
||||
import numpy as np
|
||||
from frozendict import frozendict
|
||||
@ -11,8 +10,8 @@ from .value_types import ValueGroup
|
||||
class NodeData:
|
||||
key: str
|
||||
values: ValueGroup
|
||||
metadata: Mapping[str, np.ndarray] = field(
|
||||
default_factory=frozendict, compare=False
|
||||
metadata: frozendict[str, np.ndarray] = field(
|
||||
default_factory=lambda: frozendict({}), compare=False
|
||||
)
|
||||
|
||||
def summary(self) -> str:
|
||||
|
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, replace
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
# Prevent circular imports while allowing the type checker to know what Qube is
|
||||
@ -155,21 +155,13 @@ def _operation(
|
||||
|
||||
if keep_intersection:
|
||||
if intersection.values:
|
||||
new_node_a = replace(
|
||||
node_a,
|
||||
data=replace(
|
||||
node_a.data,
|
||||
values=intersection.values,
|
||||
metadata=intersection.metadata,
|
||||
),
|
||||
new_node_a = node_a.replace(
|
||||
values=intersection.values,
|
||||
metadata=intersection.metadata,
|
||||
)
|
||||
new_node_b = replace(
|
||||
node_b,
|
||||
data=replace(
|
||||
node_b.data,
|
||||
values=intersection.values,
|
||||
metadata=intersection.metadata,
|
||||
),
|
||||
new_node_b = node_b.replace(
|
||||
values=intersection.values,
|
||||
metadata=intersection.metadata,
|
||||
)
|
||||
yield operation(new_node_a, new_node_b, operation_type, node_type)
|
||||
|
||||
@ -211,21 +203,21 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
|
||||
new_children = []
|
||||
for child_set in identical_children.values():
|
||||
if len(child_set) > 1:
|
||||
child_set = list(child_set)
|
||||
example = child_set[0]
|
||||
child_list = list(child_set)
|
||||
example = child_list[0]
|
||||
node_type = type(example)
|
||||
key = child_set[0].key
|
||||
key = child_list[0].key
|
||||
|
||||
# Compress the children into a single node
|
||||
assert all(isinstance(child.data.values, QEnum) for child in child_set), (
|
||||
assert all(isinstance(child.data.values, QEnum) for child in child_list), (
|
||||
"All children must have QEnum values"
|
||||
)
|
||||
|
||||
metadata_groups = {
|
||||
k: [child.metadata[k] for child in child_set]
|
||||
k: [child.metadata[k] for child in child_list]
|
||||
for k in example.metadata.keys()
|
||||
}
|
||||
metadata: dict[str, np.ndarray] = frozendict(
|
||||
metadata: frozendict[str, np.ndarray] = frozendict(
|
||||
{
|
||||
k: np.concatenate(metadata_group, axis=-1)
|
||||
for k, metadata_group in metadata_groups.items()
|
||||
@ -235,11 +227,9 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
|
||||
node_data = NodeData(
|
||||
key=key,
|
||||
metadata=metadata,
|
||||
values=QEnum(
|
||||
(v for child in child_set for v in child.data.values.values)
|
||||
),
|
||||
values=QEnum((v for child in child_list for v in child.data.values)),
|
||||
)
|
||||
new_child = node_type(data=node_data, children=child_set[0].children)
|
||||
new_child = node_type(data=node_data, children=child_list[0].children)
|
||||
else:
|
||||
# If the group is size one just keep it
|
||||
new_child = child_set.pop()
|
||||
|
Loading…
x
Reference in New Issue
Block a user