From fa646aee7718307be17c721f9cccbb5649963816 Mon Sep 17 00:00:00 2001 From: Tom Date: Thu, 24 Apr 2025 10:28:52 +0100 Subject: [PATCH] cosmetics --- src/python/qubed/Qube.py | 4 +-- src/python/qubed/metadata.py | 2 +- src/python/qubed/node_types.py | 5 ++-- src/python/qubed/set_operations.py | 40 +++++++++++------------------- 4 files changed, 20 insertions(+), 31 deletions(-) diff --git a/src/python/qubed/Qube.py b/src/python/qubed/Qube.py index cbfe7da..dc838b7 100644 --- a/src/python/qubed/Qube.py +++ b/src/python/qubed/Qube.py @@ -11,7 +11,7 @@ from collections.abc import Callable from dataclasses import dataclass from functools import cached_property from pathlib import Path -from typing import Any, Iterable, Iterator, Literal, Mapping, Sequence +from typing import Any, Iterable, Iterator, Literal, Sequence import numpy as np from frozendict import frozendict @@ -46,7 +46,7 @@ class Qube: return self.data.values @property - def metadata(self) -> Mapping[str, np.ndarray]: + def metadata(self): return self.data.metadata def replace(self, **kwargs) -> Qube: diff --git a/src/python/qubed/metadata.py b/src/python/qubed/metadata.py index 4719d1c..1436178 100644 --- a/src/python/qubed/metadata.py +++ b/src/python/qubed/metadata.py @@ -5,7 +5,7 @@ import numpy as np from .value_types import QEnum if TYPE_CHECKING: - from .qube import Qube + from .Qube import Qube def make_node( diff --git a/src/python/qubed/node_types.py b/src/python/qubed/node_types.py index 7293687..433198f 100644 --- a/src/python/qubed/node_types.py +++ b/src/python/qubed/node_types.py @@ -1,5 +1,4 @@ from dataclasses import dataclass, field -from typing import Mapping import numpy as np from frozendict import frozendict @@ -11,8 +10,8 @@ from .value_types import ValueGroup class NodeData: key: str values: ValueGroup - metadata: Mapping[str, np.ndarray] = field( - default_factory=frozendict, compare=False + metadata: frozendict[str, np.ndarray] = field( + default_factory=lambda: frozendict({}), compare=False ) def summary(self) -> str: diff --git a/src/python/qubed/set_operations.py b/src/python/qubed/set_operations.py index 51cb85e..63f9797 100644 --- a/src/python/qubed/set_operations.py +++ b/src/python/qubed/set_operations.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections import defaultdict -from dataclasses import dataclass, replace +from dataclasses import dataclass from enum import Enum # Prevent circular imports while allowing the type checker to know what Qube is @@ -155,21 +155,13 @@ def _operation( if keep_intersection: if intersection.values: - new_node_a = replace( - node_a, - data=replace( - node_a.data, - values=intersection.values, - metadata=intersection.metadata, - ), + new_node_a = node_a.replace( + values=intersection.values, + metadata=intersection.metadata, ) - new_node_b = replace( - node_b, - data=replace( - node_b.data, - values=intersection.values, - metadata=intersection.metadata, - ), + new_node_b = node_b.replace( + values=intersection.values, + metadata=intersection.metadata, ) yield operation(new_node_a, new_node_b, operation_type, node_type) @@ -211,21 +203,21 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]: new_children = [] for child_set in identical_children.values(): if len(child_set) > 1: - child_set = list(child_set) - example = child_set[0] + child_list = list(child_set) + example = child_list[0] node_type = type(example) - key = child_set[0].key + key = child_list[0].key # Compress the children into a single node - assert all(isinstance(child.data.values, QEnum) for child in child_set), ( + assert all(isinstance(child.data.values, QEnum) for child in child_list), ( "All children must have QEnum values" ) metadata_groups = { - k: [child.metadata[k] for child in child_set] + k: [child.metadata[k] for child in child_list] for k in example.metadata.keys() } - metadata: dict[str, np.ndarray] = frozendict( + metadata: frozendict[str, np.ndarray] = frozendict( { k: np.concatenate(metadata_group, axis=-1) for k, metadata_group in metadata_groups.items() @@ -235,11 +227,9 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]: node_data = NodeData( key=key, metadata=metadata, - values=QEnum( - (v for child in child_set for v in child.data.values.values) - ), + values=QEnum((v for child in child_list for v in child.data.values)), ) - new_child = node_type(data=node_data, children=child_set[0].children) + new_child = node_type(data=node_data, children=child_list[0].children) else: # If the group is size one just keep it new_child = child_set.pop()