cosmetics
This commit is contained in:
parent
e04c0dd3bc
commit
fa646aee77
@ -11,7 +11,7 @@ from collections.abc import Callable
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Iterable, Iterator, Literal, Mapping, Sequence
|
from typing import Any, Iterable, Iterator, Literal, Sequence
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
@ -46,7 +46,7 @@ class Qube:
|
|||||||
return self.data.values
|
return self.data.values
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def metadata(self) -> Mapping[str, np.ndarray]:
|
def metadata(self):
|
||||||
return self.data.metadata
|
return self.data.metadata
|
||||||
|
|
||||||
def replace(self, **kwargs) -> Qube:
|
def replace(self, **kwargs) -> Qube:
|
||||||
|
@ -5,7 +5,7 @@ import numpy as np
|
|||||||
from .value_types import QEnum
|
from .value_types import QEnum
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .qube import Qube
|
from .Qube import Qube
|
||||||
|
|
||||||
|
|
||||||
def make_node(
|
def make_node(
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Mapping
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
@ -11,8 +10,8 @@ from .value_types import ValueGroup
|
|||||||
class NodeData:
|
class NodeData:
|
||||||
key: str
|
key: str
|
||||||
values: ValueGroup
|
values: ValueGroup
|
||||||
metadata: Mapping[str, np.ndarray] = field(
|
metadata: frozendict[str, np.ndarray] = field(
|
||||||
default_factory=frozendict, compare=False
|
default_factory=lambda: frozendict({}), compare=False
|
||||||
)
|
)
|
||||||
|
|
||||||
def summary(self) -> str:
|
def summary(self) -> str:
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass, replace
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
# Prevent circular imports while allowing the type checker to know what Qube is
|
# Prevent circular imports while allowing the type checker to know what Qube is
|
||||||
@ -155,21 +155,13 @@ def _operation(
|
|||||||
|
|
||||||
if keep_intersection:
|
if keep_intersection:
|
||||||
if intersection.values:
|
if intersection.values:
|
||||||
new_node_a = replace(
|
new_node_a = node_a.replace(
|
||||||
node_a,
|
|
||||||
data=replace(
|
|
||||||
node_a.data,
|
|
||||||
values=intersection.values,
|
values=intersection.values,
|
||||||
metadata=intersection.metadata,
|
metadata=intersection.metadata,
|
||||||
),
|
|
||||||
)
|
)
|
||||||
new_node_b = replace(
|
new_node_b = node_b.replace(
|
||||||
node_b,
|
|
||||||
data=replace(
|
|
||||||
node_b.data,
|
|
||||||
values=intersection.values,
|
values=intersection.values,
|
||||||
metadata=intersection.metadata,
|
metadata=intersection.metadata,
|
||||||
),
|
|
||||||
)
|
)
|
||||||
yield operation(new_node_a, new_node_b, operation_type, node_type)
|
yield operation(new_node_a, new_node_b, operation_type, node_type)
|
||||||
|
|
||||||
@ -211,21 +203,21 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
|
|||||||
new_children = []
|
new_children = []
|
||||||
for child_set in identical_children.values():
|
for child_set in identical_children.values():
|
||||||
if len(child_set) > 1:
|
if len(child_set) > 1:
|
||||||
child_set = list(child_set)
|
child_list = list(child_set)
|
||||||
example = child_set[0]
|
example = child_list[0]
|
||||||
node_type = type(example)
|
node_type = type(example)
|
||||||
key = child_set[0].key
|
key = child_list[0].key
|
||||||
|
|
||||||
# Compress the children into a single node
|
# Compress the children into a single node
|
||||||
assert all(isinstance(child.data.values, QEnum) for child in child_set), (
|
assert all(isinstance(child.data.values, QEnum) for child in child_list), (
|
||||||
"All children must have QEnum values"
|
"All children must have QEnum values"
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata_groups = {
|
metadata_groups = {
|
||||||
k: [child.metadata[k] for child in child_set]
|
k: [child.metadata[k] for child in child_list]
|
||||||
for k in example.metadata.keys()
|
for k in example.metadata.keys()
|
||||||
}
|
}
|
||||||
metadata: dict[str, np.ndarray] = frozendict(
|
metadata: frozendict[str, np.ndarray] = frozendict(
|
||||||
{
|
{
|
||||||
k: np.concatenate(metadata_group, axis=-1)
|
k: np.concatenate(metadata_group, axis=-1)
|
||||||
for k, metadata_group in metadata_groups.items()
|
for k, metadata_group in metadata_groups.items()
|
||||||
@ -235,11 +227,9 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
|
|||||||
node_data = NodeData(
|
node_data = NodeData(
|
||||||
key=key,
|
key=key,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
values=QEnum(
|
values=QEnum((v for child in child_list for v in child.data.values)),
|
||||||
(v for child in child_set for v in child.data.values.values)
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
new_child = node_type(data=node_data, children=child_set[0].children)
|
new_child = node_type(data=node_data, children=child_list[0].children)
|
||||||
else:
|
else:
|
||||||
# If the group is size one just keep it
|
# If the group is size one just keep it
|
||||||
new_child = child_set.pop()
|
new_child = child_set.pop()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user