diff --git a/src/python/qubed/Qube.py b/src/python/qubed/Qube.py
index 224eb0d..b9fb503 100644
--- a/src/python/qubed/Qube.py
+++ b/src/python/qubed/Qube.py
@@ -49,9 +49,15 @@ class Qube:
def metadata(self):
return self.data.metadata
+ @property
+ def dtype(self):
+ return self.data.dtype
+
def replace(self, **kwargs) -> Qube:
data_keys = {
- k: v for k, v in kwargs.items() if k in ["key", "values", "metadata"]
+ k: v
+ for k, v in kwargs.items()
+ if k in ["key", "values", "metadata", "dtype"]
}
node_keys = {k: v for k, v in kwargs.items() if k == "children"}
if not data_keys and not node_keys:
@@ -69,7 +75,9 @@ class Qube:
@classmethod
def make(cls, key: str, values: ValueGroup, children, **kwargs) -> Qube:
return cls(
- data=NodeData(key, values, metadata=kwargs.get("metadata", frozendict())),
+ data=NodeData(
+ key, values, metadata=frozendict(kwargs.get("metadata", frozendict()))
+ ),
children=tuple(sorted(children, key=lambda n: ((n.key, n.values.min())))),
)
@@ -217,11 +225,17 @@ class Qube:
if name is not None
else self
)
- return "".join(node_tree_to_string(node=node, depth=depth))
+ out = "".join(node_tree_to_string(node=node, depth=depth))
+ if out.endswith("\n"):  # unlike out[-1], safe when the rendering is empty
+ out = out[:-1]
+ return out
def __str__(self):
return self.__str_helper__()
+ def __repr__(self):
+ return f"Qube({self.__str_helper__()})"
+
def print(self, depth=None, name: str | None = None):
print(self.__str_helper__(depth=depth, name=name))
@@ -409,7 +423,10 @@ class Qube:
def convert(node: Qube) -> Qube:
if node.key in converters:
converter = converters[node.key]
- new_node = node.replace(values=QEnum(map(converter, node.values)))
+ values = [converter(v) for v in node.values]
+ # Infer dtype from the first converted value; keep the old dtype if there are none
+ dtype = type(values[0]) if values else node.dtype
+ new_node = node.replace(values=QEnum(values), dtype=dtype)
return new_node
return node
diff --git a/src/python/qubed/metadata.py b/src/python/qubed/metadata.py
index 1436178..db460d5 100644
--- a/src/python/qubed/metadata.py
+++ b/src/python/qubed/metadata.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from typing import TYPE_CHECKING, Iterator
import numpy as np
@@ -13,7 +15,7 @@ def make_node(
key: str,
values: Iterator,
shape: list[int],
- children: "tuple[Qube]",
+ children: tuple[Qube, ...],
metadata: dict[str, np.ndarray] | None = None,
):
return cls.make(
@@ -30,11 +32,11 @@ def from_nodes(cls, nodes, add_root=True):
shape = [len(n["values"]) for n in nodes.values()]
nodes = nodes.items()
*nodes, (key, info) = nodes
- root = make_node(shape=shape, children=(), key=key, **info)
+ root = make_node(cls, shape=shape, children=(), key=key, **info)
for key, info in reversed(nodes):
shape.pop()
- root = make_node(shape=shape, children=(root,), key=key, **info)
+ root = make_node(cls, shape=shape, children=(root,), key=key, **info)
if add_root:
return cls.root_node(children=(root,))
diff --git a/src/python/qubed/node_types.py b/src/python/qubed/node_types.py
index 433198f..563d813 100644
--- a/src/python/qubed/node_types.py
+++ b/src/python/qubed/node_types.py
@@ -13,6 +13,7 @@ class NodeData:
metadata: frozendict[str, np.ndarray] = field(
default_factory=lambda: frozendict({}), compare=False
)
+ dtype: type = str
def summary(self) -> str:
return f"{self.key}={self.values.summary()}" if self.key != "root" else "root"
diff --git a/src/python/qubed/set_operations.py b/src/python/qubed/set_operations.py
index 63f9797..037b952 100644
--- a/src/python/qubed/set_operations.py
+++ b/src/python/qubed/set_operations.py
@@ -40,7 +40,6 @@ def QEnum_intersection(
for index_a, val_A in enumerate(A.values):
if val_A in B.values:
- # print(f"{val_A} in both")
just_B.pop(val_A)
intersection[val_A] = (
index_a # We throw away any overlapping metadata from B
@@ -116,9 +115,8 @@ def operation(A: Qube, B: Qube, operation_type: SetOperation, node_type) -> Qube
# For every node group, perform the set operation
for key, (A_nodes, B_nodes) in nodes_by_key.items():
- new_children.extend(
- _operation(key, A_nodes, B_nodes, operation_type, node_type)
- )
+ output = list(_operation(key, A_nodes, B_nodes, operation_type, node_type))
+ new_children.extend(output)
# Whenever we modify children we should recompress them
# But since `operation` is already recursive, we only need to compress this level not all levels
@@ -193,17 +191,17 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
"""
# Take the set of new children and see if any have identical key, metadata and children
# the values may different and will be collapsed into a single node
- identical_children = defaultdict(set)
+
+ identical_children = defaultdict(list)
for child in children:
# only care about the key and children of each node, ignore values
h = hash((child.key, tuple((cc.structural_hash for cc in child.children))))
- identical_children[h].add(child)
+ identical_children[h].append(child)
# Now go through and create new compressed nodes for any groups that need collapsing
new_children = []
- for child_set in identical_children.values():
- if len(child_set) > 1:
- child_list = list(child_set)
+ for child_list in identical_children.values():
+ if len(child_list) > 1:
example = child_list[0]
node_type = type(example)
key = child_list[0].key
@@ -217,9 +215,10 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
k: [child.metadata[k] for child in child_list]
for k in example.metadata.keys()
}
+
metadata: frozendict[str, np.ndarray] = frozendict(
{
- k: np.concatenate(metadata_group, axis=-1)
+ k: np.concatenate(metadata_group, axis=0)
for k, metadata_group in metadata_groups.items()
}
)
@@ -227,12 +226,15 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
node_data = NodeData(
key=key,
metadata=metadata,
- values=QEnum((v for child in child_list for v in child.data.values)),
+ values=QEnum(set(v for child in child_list for v in child.data.values)),
+ dtype=example.data.dtype,  # NodeData defaults dtype to str; keep the merged nodes' dtype
)
- new_child = node_type(data=node_data, children=child_list[0].children)
+ children = [cc for c in child_list for cc in c.children]
+ compressed_children = compress_children(children)
+ new_child = node_type(data=node_data, children=compressed_children)
else:
# If the group is size one just keep it
- new_child = child_set.pop()
+ new_child = child_list.pop()
new_children.append(new_child)
diff --git a/src/python/qubed/tree_formatters.py b/src/python/qubed/tree_formatters.py
index 7c83af8..49ead5a 100644
--- a/src/python/qubed/tree_formatters.py
+++ b/src/python/qubed/tree_formatters.py
@@ -4,6 +4,8 @@ import random
from dataclasses import dataclass
from typing import TYPE_CHECKING, Iterable
+import numpy as np
+
if TYPE_CHECKING:
from .Qube import Qube
@@ -68,17 +70,51 @@ def node_tree_to_string(node: Qube, prefix: str = "", depth=None) -> Iterable[st
)
+def summarize_node_html(
+ node: Qube, collapse=False, max_summary_length=50, **kwargs
+) -> tuple[str, Qube]:
+ """
+ Extracts a summarized representation of the node while collapsing single-child paths.
+ Returns the summary string and the last node in the chain that has multiple children.
+ """
+ summaries = []
+
+ while True:
+ path = node.summary(**kwargs)
+ summary = path
+ if "is_leaf" in node.metadata and node.metadata["is_leaf"]:
+ summary += " 🌿"
+
+ if len(summary) > max_summary_length:
+ summary = summary[:max_summary_length] + "..."
+ info = (
+ f"dtype: {node.dtype.__name__}\n"
+ f"metadata: {dict((k, np.shape(v)) for k, v in node.metadata.items())}\n"
+ )
+ summary = f'{summary}'
+ summaries.append(summary)
+ if not collapse:
+ break
+
+ # Move down if there's exactly one child, otherwise stop
+ if len(node.children) != 1:
+ break
+ node = node.children[0]
+
+ return ", ".join(summaries), node
+
+
def _node_tree_to_html(
node: Qube, prefix: str = "", depth=1, connector="", **kwargs
) -> Iterable[str]:
- summary, path, node = summarize_node(node, **kwargs)
+ summary, node = summarize_node_html(node, **kwargs)
if len(node.children) == 0:
- yield f'{connector}{summary}'
+ yield f'{connector}{summary}'
return
else:
open = "open" if depth > 0 else ""
- yield f'{connector}{summary}
'
+ yield f'{connector}{summary}
'
for index, child in enumerate(node.children):
connector = "└── " if index == len(node.children) - 1 else "├── "
@@ -114,7 +150,7 @@ def node_tree_to_html(
margin-left: 0;
}
- .qubed-node a {
+ .qubed-level a {
margin-left: 10px;
text-decoration: none;
}
@@ -128,7 +164,7 @@ def node_tree_to_html(
display: block;
}
- summary:hover,span.leaf:hover {
+ span.qubed-node:hover {
background-color: #f0f0f0;
}
@@ -140,7 +176,7 @@ def node_tree_to_html(
content: " â–¼";
}
- .leaf {
+ .qubed-level {
text-overflow: ellipsis;
overflow: hidden;
text-wrap: nowrap;
diff --git a/tests/test_formatters.py b/tests/test_formatters.py
index 5eb8685..7278927 100644
--- a/tests/test_formatters.py
+++ b/tests/test_formatters.py
@@ -21,7 +21,7 @@ root
""".strip()
as_html = """
-root
├── class=od, expver=0001/0002, param=1/2└── class=rd
├── expver=0001, param=1/2/3 └── expver=0002, param=1/2
+root
├── class=od, expver=0001/0002, param=1/2└── class=rd
├── expver=0001, param=1/2/3 └── expver=0002, param=1/2
""".strip()
diff --git a/tests/test_metadata.py b/tests/test_metadata.py
index e69de29..6d9e416 100644
--- a/tests/test_metadata.py
+++ b/tests/test_metadata.py
@@ -0,0 +1,45 @@
+from frozendict import frozendict
+from qubed import Qube
+
+
+def make_set(entries):
+ return set((frozendict(a), frozendict(b)) for a, b in entries)
+
+
+def test_simple_union():
+ q = Qube.from_nodes(
+ {
+ "class": dict(values=["od", "rd"]),
+ "expver": dict(values=[1, 2]),
+ "stream": dict(
+ values=["a", "b", "c"], metadata=dict(number=list(range(12)))
+ ),
+ }
+ )
+
+ r = Qube.from_nodes(
+ {
+ "class": dict(values=["xd"]),
+ "expver": dict(values=[1, 2]),
+ "stream": dict(
+ values=["a", "b", "c"], metadata=dict(number=list(range(12, 18)))
+ ),
+ }
+ )
+
+ expected_union = Qube.from_nodes(
+ {
+ "class": dict(values=["od", "rd", "xd"]),
+ "expver": dict(values=[1, 2]),
+ "stream": dict(
+ values=["a", "b", "c"], metadata=dict(number=list(range(18)))
+ ),
+ }
+ )
+
+ union = q | r
+
+ assert union == expected_union
+ assert make_set(expected_union.leaves_with_metadata()) == make_set(
+ union.leaves_with_metadata()
+ )