diff --git a/src/python/qubed/Qube.py b/src/python/qubed/Qube.py index 224eb0d..b9fb503 100644 --- a/src/python/qubed/Qube.py +++ b/src/python/qubed/Qube.py @@ -49,9 +49,15 @@ class Qube: def metadata(self): return self.data.metadata + @property + def dtype(self): + return self.data.dtype + def replace(self, **kwargs) -> Qube: data_keys = { - k: v for k, v in kwargs.items() if k in ["key", "values", "metadata"] + k: v + for k, v in kwargs.items() + if k in ["key", "values", "metadata", "dtype"] } node_keys = {k: v for k, v in kwargs.items() if k == "children"} if not data_keys and not node_keys: @@ -69,7 +75,9 @@ class Qube: @classmethod def make(cls, key: str, values: ValueGroup, children, **kwargs) -> Qube: return cls( - data=NodeData(key, values, metadata=kwargs.get("metadata", frozendict())), + data=NodeData( + key, values, metadata=frozendict(kwargs.get("metadata", frozendict())) + ), children=tuple(sorted(children, key=lambda n: ((n.key, n.values.min())))), ) @@ -217,11 +225,17 @@ class Qube: if name is not None else self ) - return "".join(node_tree_to_string(node=node, depth=depth)) + out = "".join(node_tree_to_string(node=node, depth=depth)) + if out[-1] == "\n": + out = out[:-1] + return out def __str__(self): return self.__str_helper__() + def __repr__(self): + return f"Qube({self.__str_helper__()})" + def print(self, depth=None, name: str | None = None): print(self.__str_helper__(depth=depth, name=name)) @@ -409,7 +423,8 @@ class Qube: def convert(node: Qube) -> Qube: if node.key in converters: converter = converters[node.key] - new_node = node.replace(values=QEnum(map(converter, node.values))) + values = [converter(v) for v in node.values] + new_node = node.replace(values=QEnum(values), dtype=type(values[0])) return new_node return node diff --git a/src/python/qubed/metadata.py b/src/python/qubed/metadata.py index 1436178..db460d5 100644 --- a/src/python/qubed/metadata.py +++ b/src/python/qubed/metadata.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import TYPE_CHECKING, Iterator import numpy as np @@ -13,7 +15,7 @@ def make_node( key: str, values: Iterator, shape: list[int], - children: "tuple[Qube]", + children: tuple[Qube, ...], metadata: dict[str, np.ndarray] | None = None, ): return cls.make( @@ -30,11 +32,11 @@ def from_nodes(cls, nodes, add_root=True): shape = [len(n["values"]) for n in nodes.values()] nodes = nodes.items() *nodes, (key, info) = nodes - root = make_node(shape=shape, children=(), key=key, **info) + root = make_node(cls, shape=shape, children=(), key=key, **info) for key, info in reversed(nodes): shape.pop() - root = make_node(shape=shape, children=(root,), key=key, **info) + root = make_node(cls, shape=shape, children=(root,), key=key, **info) if add_root: return cls.root_node(children=(root,)) diff --git a/src/python/qubed/node_types.py b/src/python/qubed/node_types.py index 433198f..563d813 100644 --- a/src/python/qubed/node_types.py +++ b/src/python/qubed/node_types.py @@ -13,6 +13,7 @@ class NodeData: metadata: frozendict[str, np.ndarray] = field( default_factory=lambda: frozendict({}), compare=False ) + dtype: type = str def summary(self) -> str: return f"{self.key}={self.values.summary()}" if self.key != "root" else "root" diff --git a/src/python/qubed/set_operations.py b/src/python/qubed/set_operations.py index 63f9797..037b952 100644 --- a/src/python/qubed/set_operations.py +++ b/src/python/qubed/set_operations.py @@ -40,7 +40,6 @@ def QEnum_intersection( for index_a, val_A in enumerate(A.values): if val_A in B.values: - # print(f"{val_A} in both") just_B.pop(val_A) intersection[val_A] = ( index_a # We throw away any overlapping metadata from B @@ -116,9 +115,8 @@ def operation(A: Qube, B: Qube, operation_type: SetOperation, node_type) -> Qube # For every node group, perform the set operation for key, (A_nodes, B_nodes) in nodes_by_key.items(): - new_children.extend( - _operation(key, A_nodes, B_nodes, operation_type, node_type) - ) + output = list(_operation(key, A_nodes, B_nodes, operation_type, node_type)) + new_children.extend(output) # Whenever we modify children we should recompress them # But since `operation` is already recursive, we only need to compress this level not all levels @@ -193,17 +191,17 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]: """ # Take the set of new children and see if any have identical key, metadata and children # the values may different and will be collapsed into a single node - identical_children = defaultdict(set) + + identical_children = defaultdict(list) for child in children: # only care about the key and children of each node, ignore values h = hash((child.key, tuple((cc.structural_hash for cc in child.children)))) - identical_children[h].add(child) + identical_children[h].append(child) # Now go through and create new compressed nodes for any groups that need collapsing new_children = [] - for child_set in identical_children.values(): - if len(child_set) > 1: - child_list = list(child_set) + for child_list in identical_children.values(): + if len(child_list) > 1: example = child_list[0] node_type = type(example) key = child_list[0].key @@ -217,9 +215,10 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]: k: [child.metadata[k] for child in child_list] for k in example.metadata.keys() } + metadata: frozendict[str, np.ndarray] = frozendict( { - k: np.concatenate(metadata_group, axis=-1) + k: np.concatenate(metadata_group, axis=0) for k, metadata_group in metadata_groups.items() } ) @@ -227,12 +226,14 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]: node_data = NodeData( key=key, metadata=metadata, - values=QEnum((v for child in child_list for v in child.data.values)), + values=QEnum(set(v for child in child_list for v in child.data.values)), ) - new_child = node_type(data=node_data, children=child_list[0].children) + children = [cc for c in child_list for cc in c.children] + compressed_children = compress_children(children) + new_child = node_type(data=node_data, children=compressed_children) else: # If the group is size one just keep it - new_child = child_set.pop() + new_child = child_list.pop() new_children.append(new_child) diff --git a/src/python/qubed/tree_formatters.py b/src/python/qubed/tree_formatters.py index 7c83af8..49ead5a 100644 --- a/src/python/qubed/tree_formatters.py +++ b/src/python/qubed/tree_formatters.py @@ -4,6 +4,8 @@ import random from dataclasses import dataclass from typing import TYPE_CHECKING, Iterable +import numpy as np + if TYPE_CHECKING: from .Qube import Qube @@ -68,17 +70,51 @@ def node_tree_to_string(node: Qube, prefix: str = "", depth=None) -> Iterable[st ) +def summarize_node_html( + node: Qube, collapse=False, max_summary_length=50, **kwargs +) -> tuple[str, Qube]: + """ + Extracts a summarized representation of the node while collapsing single-child paths. + Returns the summary string and the last node in the chain that has multiple children. + """ + summaries = [] + + while True: + path = node.summary(**kwargs) + summary = path + if "is_leaf" in node.metadata and node.metadata["is_leaf"]: + summary += " 🌿" + + if len(summary) > max_summary_length: + summary = summary[:max_summary_length] + "..." + info = ( + f"dtype: {node.dtype.__name__}\n" + f"metadata: {dict((k, np.shape(v)) for k, v in node.metadata.items())}\n" + ) + summary = f'{summary}' + summaries.append(summary) + if not collapse: + break + + # Move down if there's exactly one child, otherwise stop + if len(node.children) != 1: + break + node = node.children[0] + + return ", ".join(summaries), node + + def _node_tree_to_html( node: Qube, prefix: str = "", depth=1, connector="", **kwargs ) -> Iterable[str]: - summary, path, node = summarize_node(node, **kwargs) + summary, node = summarize_node_html(node, **kwargs) if len(node.children) == 0: - yield f'{connector}{summary}' + yield f'{connector}{summary}' return else: open = "open" if depth > 0 else "" - yield f'
{connector}{summary}' + yield f'
{connector}{summary}' for index, child in enumerate(node.children): connector = "└── " if index == len(node.children) - 1 else "├── " @@ -114,7 +150,7 @@ def node_tree_to_html( margin-left: 0; } - .qubed-node a { + .qubed-level a { margin-left: 10px; text-decoration: none; } @@ -128,7 +164,7 @@ def node_tree_to_html( display: block; } - summary:hover,span.leaf:hover { + span.qubed-node:hover { background-color: #f0f0f0; } @@ -140,7 +176,7 @@ def node_tree_to_html( content: " ▼"; } - .leaf { + .qubed-level { text-overflow: ellipsis; overflow: hidden; text-wrap: nowrap; diff --git a/tests/test_formatters.py b/tests/test_formatters.py index 5eb8685..7278927 100644 --- a/tests/test_formatters.py +++ b/tests/test_formatters.py @@ -21,7 +21,7 @@ root """.strip() as_html = """ -
root├── class=od, expver=0001/0002, param=1/2
└── class=rd ├── expver=0001, param=1/2/3 └── expver=0002, param=1/2
+
root├── class=od, expver=0001/0002, param=1/2
└── class=rd ├── expver=0001, param=1/2/3 └── expver=0002, param=1/2
""".strip() diff --git a/tests/test_metadata.py b/tests/test_metadata.py index e69de29..6d9e416 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -0,0 +1,45 @@ +from frozendict import frozendict +from qubed import Qube + + +def make_set(entries): + return set((frozendict(a), frozendict(b)) for a, b in entries) + + +def test_simple_union(): + q = Qube.from_nodes( + { + "class": dict(values=["od", "rd"]), + "expver": dict(values=[1, 2]), + "stream": dict( + values=["a", "b", "c"], metadata=dict(number=list(range(12))) + ), + } + ) + + r = Qube.from_nodes( + { + "class": dict(values=["xd"]), + "expver": dict(values=[1, 2]), + "stream": dict( + values=["a", "b", "c"], metadata=dict(number=list(range(12, 18))) + ), + } + ) + + expected_union = Qube.from_nodes( + { + "class": dict(values=["od", "rd", "xd"]), + "expver": dict(values=[1, 2]), + "stream": dict( + values=["a", "b", "c"], metadata=dict(number=list(range(18))) + ), + } + ) + + union = q | r + + assert union == expected_union + assert make_set(expected_union.leaves_with_metadata()) == make_set( + union.leaves_with_metadata() + )