From 35bb8f0edd487fd05519f786d4bc37666e40da4f Mon Sep 17 00:00:00 2001
From: Tom <thomas.hodson@ecmwf.int>
Date: Wed, 14 May 2025 10:14:02 +0100
Subject: [PATCH] Massive rewrite

---
 src/python/qubed/Qube.py              | 228 +++++++++++++++++---------
 src/python/qubed/__init__.py          |   3 +-
 src/python/qubed/metadata.py          |   4 +-
 src/python/qubed/node_types.py        |  27 ---
 src/python/qubed/protobuf/__init__.py |   0
 src/python/qubed/protobuf/adapters.py |  99 +++++++++++
 src/python/qubed/protobuf/qube_pb2.py |  45 +++++
 src/python/qubed/set_operations.py    | 153 +++++++++++------
 src/python/qubed/tree_formatters.py   |  50 ++++--
 src/python/qubed/value_types.py       |  21 ++-
 src/qube.proto                        |  32 ++++
 stac_server/main.py                   |   2 +-
 tests/test_basic_operations.py        |  60 +++----
 tests/test_conversions.py             |  55 ++++++-
 tests/test_formatters.py              |   8 -
 tests/test_metadata.py                |  67 ++++----
 tests/test_protobuf.py                |  12 ++
 17 files changed, 623 insertions(+), 243 deletions(-)
 delete mode 100644 src/python/qubed/node_types.py
 create mode 100644 src/python/qubed/protobuf/__init__.py
 create mode 100644 src/python/qubed/protobuf/adapters.py
 create mode 100644 src/python/qubed/protobuf/qube_pb2.py
 create mode 100644 src/qube.proto
 create mode 100644 tests/test_protobuf.py

diff --git a/src/python/qubed/Qube.py b/src/python/qubed/Qube.py
index b9fb503..79a004e 100644
--- a/src/python/qubed/Qube.py
+++ b/src/python/qubed/Qube.py
@@ -8,17 +8,17 @@ import functools
 import json
 from collections import defaultdict
 from collections.abc import Callable
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import cached_property
 from pathlib import Path
-from typing import Any, Iterable, Iterator, Literal, Sequence
+from typing import Any, Iterable, Iterator, Literal, Self, Sequence
 
 import numpy as np
 from frozendict import frozendict
 
 from . import set_operations
 from .metadata import from_nodes
-from .node_types import NodeData, RootNodeData
+from .protobuf.adapters import proto_to_qube, qube_to_proto
 from .tree_formatters import (
     HTML,
     node_tree_to_html,
@@ -32,58 +32,90 @@ from .value_types import (
 )
 
 
-@dataclass(frozen=False, eq=True, order=True, unsafe_hash=True)
-class Qube:
-    data: NodeData
-    children: tuple[Qube, ...]
+@dataclass
+class AxisInfo:
+    key: str
+    type: Any
+    depths: set[int]
+    values: set
 
-    @property
-    def key(self) -> str:
-        return self.data.key
+    def combine(self, other: Self):
+        self.key = other.key
+        self.type = other.type
+        self.depths.update(other.depths)
+        self.values.update(other.values)
+        # print(f"combining {self} and {other} getting {result}")
 
-    @property
-    def values(self) -> ValueGroup:
-        return self.data.values
-
-    @property
-    def metadata(self):
-        return self.data.metadata
-
-    @property
-    def dtype(self):
-        return self.data.dtype
-
-    def replace(self, **kwargs) -> Qube:
-        data_keys = {
-            k: v
-            for k, v in kwargs.items()
-            if k in ["key", "values", "metadata", "dtype"]
+    def to_json(self):
+        return {
+            "key": self.key,
+            "type": self.type.__name__,
+            "values": list(self.values),
+            "depths": list(self.depths),
         }
-        node_keys = {k: v for k, v in kwargs.items() if k == "children"}
-        if not data_keys and not node_keys:
-            return self
-        if not data_keys:
-            return dataclasses.replace(self, **node_keys)
 
-        return dataclasses.replace(
-            self, data=dataclasses.replace(self.data, **data_keys), **node_keys
-        )
+
+@dataclass(frozen=True, eq=True, order=True, unsafe_hash=True)
+class QubeNamedRoot:
+    "Helper class to print a custom root name"
+
+    key: str
+    dtype: str = "str"
+    children: tuple[Qube, ...] = ()
 
     def summary(self) -> str:
-        return self.data.summary()
+        return self.key
+
+
+@dataclass(frozen=True, eq=True, order=True, unsafe_hash=True)
+class Qube:
+    key: str
+    values: ValueGroup
+    metadata: frozendict[str, np.ndarray] = field(
+        default_factory=lambda: frozendict({}), compare=False
+    )
+    children: tuple[Qube, ...] = ()
+    is_root: bool = False
+
+    def replace(self, **kwargs) -> Qube:
+        return dataclasses.replace(self, **kwargs)
+
+    def summary(self) -> str:
+        if self.is_root:
+            return self.key
+        return f"{self.key}={self.values.summary()}" if self.key != "root" else "root"
 
     @classmethod
-    def make(cls, key: str, values: ValueGroup, children, **kwargs) -> Qube:
+    def make_node(
+        cls,
+        key: str,
+        values: Iterable | QEnum | WildcardGroup,
+        children: Iterable[Qube],
+        metadata: dict[str, np.ndarray] = {},
+        is_root: bool = False,
+    ) -> Qube:
+        if isinstance(values, ValueGroup):
+            values = values
+        else:
+            values = QEnum(values)
+
         return cls(
-            data=NodeData(
-                key, values, metadata=frozendict(kwargs.get("metadata", frozendict()))
-            ),
+            key,
+            values=values,
             children=tuple(sorted(children, key=lambda n: ((n.key, n.values.min())))),
+            metadata=frozendict(metadata),
+            is_root=is_root,
         )
 
     @classmethod
-    def root_node(cls, children: Iterable[Qube]) -> Qube:
-        return cls.make("root", QEnum(("root",)), children)
+    def make_root(cls, children: Iterable[Qube], metadata={}) -> Qube:
+        return cls.make_node(
+            "root",
+            values=QEnum(("root",)),
+            children=children,
+            metadata=metadata,
+            is_root=True,
+        )
 
     @classmethod
     def load(cls, path: str | Path) -> Qube:
@@ -104,18 +136,19 @@ class Qube:
             else:
                 values_group = QEnum([values])
 
-            children = [cls.make(key, values_group, children)]
+            children = [cls.make_node(key, values_group, children)]
 
-        return cls.root_node(children)
+        return cls.make_root(children)
 
     @classmethod
     def from_json(cls, json: dict) -> Qube:
-        def from_json(json: dict) -> Qube:
-            return Qube.make(
+        def from_json(json: dict, depth=0) -> Qube:
+            return Qube.make_node(
                 key=json["key"],
                 values=values_from_json(json["values"]),
                 metadata=frozendict(json["metadata"]) if "metadata" in json else {},
-                children=(from_json(c) for c in json["children"]),
+                children=(from_json(c, depth + 1) for c in json["children"]),
+                is_root=(depth == 0),
             )
 
         return from_json(json)
@@ -146,13 +179,13 @@ class Qube:
                 else:
                     values = QEnum(values)
 
-                yield Qube.make(
+                yield Qube.make_node(
                     key=key,
                     values=values,
                     children=from_dict(children),
                 )
 
-        return Qube.root_node(list(from_dict(d)))
+        return Qube.make_root(list(from_dict(d)))
 
     def to_dict(self) -> dict:
         def to_dict(q: Qube) -> tuple[str, dict]:
@@ -161,6 +194,13 @@ class Qube:
 
         return to_dict(self)[1]
 
+    @classmethod
+    def from_protobuf(cls, msg: bytes) -> Qube:
+        return proto_to_qube(cls, msg)
+
+    def to_protobuf(self) -> bytes:
+        return qube_to_proto(self)
+
     @classmethod
     def from_tree(cls, tree_str):
         lines = tree_str.splitlines()
@@ -214,17 +254,12 @@ class Qube:
 
     @classmethod
     def empty(cls) -> Qube:
-        return Qube.root_node([])
+        return Qube.make_root([])
 
     def __str_helper__(self, depth=None, name=None) -> str:
-        node = (
-            dataclasses.replace(
-                self,
-                data=RootNodeData(key=name, values=self.values, metadata=self.metadata),
-            )
-            if name is not None
-            else self
-        )
+        node = self
+        if name is not None:
+            node = node.replace(key=name)
         out = "".join(node_tree_to_string(node=node, depth=depth))
         if out[-1] == "\n":
             out = out[:-1]
@@ -239,16 +274,19 @@ class Qube:
     def print(self, depth=None, name: str | None = None):
         print(self.__str_helper__(depth=depth, name=name))
 
-    def html(self, depth=2, collapse=True, name: str | None = None) -> HTML:
-        node = (
-            dataclasses.replace(
-                self,
-                data=RootNodeData(key=name, values=self.values, metadata=self.metadata),
-            )
-            if name is not None
-            else self
+    def html(
+        self,
+        depth=2,
+        collapse=True,
+        name: str | None = None,
+        info: Callable[[Qube], str] | None = None,
+    ) -> HTML:
+        node = self
+        if name is not None:
+            node = node.replace(key=name)
+        return HTML(
+            node_tree_to_html(node=node, depth=depth, collapse=collapse, info=info)
         )
-        return HTML(node_tree_to_html(node=node, depth=depth, collapse=collapse))
 
     def _repr_html_(self) -> str:
         return node_tree_to_html(self, depth=2, collapse=True)
@@ -257,7 +295,7 @@ class Qube:
     def __rtruediv__(self, other: str) -> Qube:
         key, values = other.split("=")
         values_enum = QEnum((values.split("/")))
-        return Qube.root_node([Qube.make(key, values_enum, self.children)])
+        return Qube.make_root([Qube.make_node(key, values_enum, self.children)])
 
     def __or__(self, other: Qube) -> Qube:
         return set_operations.operation(
@@ -358,16 +396,16 @@ class Qube:
                     raise KeyError(
                         f"Key '{key}' not found in children of '{current.key}', available keys are {[c.key for c in current.children]}"
                     )
-            return Qube.root_node(current.children)
+            return Qube.make_root(current.children)
 
         elif isinstance(args, tuple) and len(args) == 2:
             key, value = args
             for c in self.children:
                 if c.key == key and value in c.values:
-                    return Qube.root_node(c.children)
-            raise KeyError(f"Key {key} not found in children of {self.key}")
+                    return Qube.make_root(c.children)
+            raise KeyError(f"Key '{key}' not found in children of {self.key}")
         else:
-            raise ValueError("Unknown key type")
+            raise ValueError(f"Unknown key type {args}")
 
     @cached_property
     def n_leaves(self) -> int:
@@ -410,7 +448,7 @@ class Qube:
             for c in node.children:
                 if c.key in _keys:
                     grandchildren = tuple(sorted(remove_key(cc) for cc in c.children))
-                    grandchildren = remove_key(Qube.root_node(grandchildren)).children
+                    grandchildren = remove_key(Qube.make_root(grandchildren)).children
                     children.extend(grandchildren)
                 else:
                     children.append(remove_key(c))
@@ -424,7 +462,7 @@ class Qube:
             if node.key in converters:
                 converter = converters[node.key]
                 values = [converter(v) for v in node.values]
-                new_node = node.replace(values=QEnum(values), dtype=type(values[0]))
+                new_node = node.replace(values=QEnum(values))
                 return new_node
             return node
 
@@ -516,7 +554,8 @@ class Qube:
 
             return node.replace(
                 children=new_children,
-                metadata=dict(self.metadata) | {"is_leaf": not bool(new_children)},
+                metadata=dict(self.metadata)
+                | ({"is_leaf": not bool(new_children)} if mode == "next_level" else {}),
             )
 
         return self.replace(
@@ -544,6 +583,26 @@ class Qube:
             axes[self.key].update(self.values)
         return dict(axes)
 
+    def axes_info(self, depth=0) -> dict[str, AxisInfo]:
+        axes = defaultdict(
+            lambda: AxisInfo(key="", type=str, depths=set(), values=set())
+        )
+        for c in self.children:
+            for k, info in c.axes_info(depth=depth + 1).items():
+                axes[k].combine(info)
+
+        if self.key != "root":
+            axes[self.key].combine(
+                AxisInfo(
+                    key=self.key,
+                    type=type(next(iter(self.values))),
+                    depths={depth},
+                    values=set(self.values),
+                )
+            )
+
+        return dict(axes)
+
     @cached_property
     def structural_hash(self) -> int:
         """
@@ -570,7 +629,7 @@ class Qube:
         """
 
         def union(a: Qube, b: Qube) -> Qube:
-            b = type(self).root_node(children=(b,))
+            b = type(self).make_root(children=(b,))
             out = set_operations.operation(
                 a, b, set_operations.SetOperation.UNION, type(self)
             )
@@ -583,3 +642,20 @@ class Qube:
             )
 
         return self.replace(children=tuple(sorted(new_children)))
+
+    def add_metadata(self, **kwargs: Any) -> Qube:
+        """Return a copy of this node whose metadata is built from ``kwargs``.
+
+        Each value is wrapped in a one-element array; existing metadata on
+        this node is replaced, not merged.
+        """
+        metadata = frozendict(
+            {k: np.array([v]) for k, v in kwargs.items()}
+        )
+        return self.replace(metadata=metadata)
+
+    def strip_metadata(self) -> Qube:
+        def strip(node):
+            return node.replace(metadata=frozendict({}))
+
+        return self.transform(strip)
diff --git a/src/python/qubed/__init__.py b/src/python/qubed/__init__.py
index 399752b..00fc3a1 100644
--- a/src/python/qubed/__init__.py
+++ b/src/python/qubed/__init__.py
@@ -1,3 +1,4 @@
+from . import protobuf
 from .Qube import Qube
 
-__all__ = ["Qube"]
+__all__ = ["Qube", "protobuf"]
diff --git a/src/python/qubed/metadata.py b/src/python/qubed/metadata.py
index db460d5..05e37ee 100644
--- a/src/python/qubed/metadata.py
+++ b/src/python/qubed/metadata.py
@@ -18,7 +18,7 @@ def make_node(
     children: tuple[Qube, ...],
     metadata: dict[str, np.ndarray] | None = None,
 ):
-    return cls.make(
+    return cls.make_node(
         key=key,
         values=QEnum(values),
         metadata={k: np.array(v).reshape(shape) for k, v in metadata.items()}
@@ -39,5 +39,5 @@ def from_nodes(cls, nodes, add_root=True):
         root = make_node(cls, shape=shape, children=(root,), key=key, **info)
 
     if add_root:
-        return cls.root_node(children=(root,))
+        return cls.make_root(children=(root,))
     return root
diff --git a/src/python/qubed/node_types.py b/src/python/qubed/node_types.py
deleted file mode 100644
index 563d813..0000000
--- a/src/python/qubed/node_types.py
+++ /dev/null
@@ -1,27 +0,0 @@
-from dataclasses import dataclass, field
-
-import numpy as np
-from frozendict import frozendict
-
-from .value_types import ValueGroup
-
-
-@dataclass(frozen=False, eq=True, order=True, unsafe_hash=True)
-class NodeData:
-    key: str
-    values: ValueGroup
-    metadata: frozendict[str, np.ndarray] = field(
-        default_factory=lambda: frozendict({}), compare=False
-    )
-    dtype: type = str
-
-    def summary(self) -> str:
-        return f"{self.key}={self.values.summary()}" if self.key != "root" else "root"
-
-
-@dataclass(frozen=False, eq=True, order=True)
-class RootNodeData(NodeData):
-    "Helper class to print a custom root name"
-
-    def summary(self) -> str:
-        return self.key
diff --git a/src/python/qubed/protobuf/__init__.py b/src/python/qubed/protobuf/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/python/qubed/protobuf/adapters.py b/src/python/qubed/protobuf/adapters.py
new file mode 100644
index 0000000..4d89baa
--- /dev/null
+++ b/src/python/qubed/protobuf/adapters.py
@@ -0,0 +1,99 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+from frozendict import frozendict
+
+from ..value_types import QEnum
+from . import qube_pb2
+
+if TYPE_CHECKING:
+    from ..Qube import Qube
+
+
+def _ndarray_to_proto(arr: np.ndarray) -> qube_pb2.NdArray:
+    """np.ndarray → NdArray message"""
+    return qube_pb2.NdArray(
+        shape=list(arr.shape),
+        dtype=str(arr.dtype),
+        raw=arr.tobytes(order="C"),
+    )
+
+
+def _ndarray_from_proto(msg: qube_pb2.NdArray) -> np.ndarray:
+    """NdArray message → np.ndarray (immutable view)"""
+    return np.frombuffer(msg.raw, dtype=msg.dtype).reshape(tuple(msg.shape))
+
+
+def _py_to_valuegroup(value: list[str] | np.ndarray) -> qube_pb2.ValueGroup:
+    """Accept str-sequence *or* ndarray and return ValueGroup."""
+    vg = qube_pb2.ValueGroup()
+    if isinstance(value, np.ndarray):
+        vg.tensor.CopyFrom(_ndarray_to_proto(value))
+    else:
+        vg.s.items.extend(value)
+    return vg
+
+
+def _valuegroup_to_py(vg: qube_pb2.ValueGroup) -> list[str] | np.ndarray:
+    """ValueGroup → QEnum (string arm) *or* ndarray (tensor arm)"""
+    arm = vg.WhichOneof("payload")
+    if arm == "tensor":
+        return _ndarray_from_proto(vg.tensor)
+
+    return QEnum(vg.s.items)
+
+
+def _py_to_metadatagroup(value: np.ndarray) -> qube_pb2.MetadataGroup:
+    """Accept an ndarray (or a bare value, wrapped in a 1-element array) and return MetadataGroup."""
+    vg = qube_pb2.MetadataGroup()
+    if not isinstance(value, np.ndarray):
+        value = np.array([value])
+
+    vg.tensor.CopyFrom(_ndarray_to_proto(value))
+    return vg
+
+
+def _metadatagroup_to_py(vg: qube_pb2.MetadataGroup) -> np.ndarray:
+    """MetadataGroup → ndarray; raises ValueError if the payload is not a tensor"""
+    arm = vg.WhichOneof("payload")
+    if arm == "tensor":
+        return _ndarray_from_proto(vg.tensor)
+
+    raise ValueError(f"Unknown arm {arm}")
+
+
+def _qube_to_proto(q: Qube) -> qube_pb2.Qube:
+    """Frozen Qube dataclass → protobuf Qube message (new object)."""
+    return qube_pb2.Qube(
+        key=q.key,
+        values=_py_to_valuegroup(q.values),
+        metadata={k: _py_to_metadatagroup(v) for k, v in q.metadata.items()},
+        children=[_qube_to_proto(c) for c in q.children],
+        is_root=q.is_root,
+    )
+
+
+def qube_to_proto(q: Qube) -> bytes:
+    return _qube_to_proto(q).SerializeToString()
+
+
+def _proto_to_qube(cls: type, msg: qube_pb2.Qube) -> Qube:
+    """protobuf Qube message → frozen Qube dataclass (new object)."""
+
+    return cls(
+        key=msg.key,
+        values=_valuegroup_to_py(msg.values),
+        metadata=frozendict(
+            {k: _metadatagroup_to_py(v) for k, v in msg.metadata.items()}
+        ),
+        children=tuple(_proto_to_qube(cls, c) for c in msg.children),
+        is_root=msg.is_root,
+    )
+
+
+def proto_to_qube(cls: type, wire: bytes) -> Qube:
+    msg = qube_pb2.Qube()
+    msg.ParseFromString(wire)
+    return _proto_to_qube(cls, msg)
diff --git a/src/python/qubed/protobuf/qube_pb2.py b/src/python/qubed/protobuf/qube_pb2.py
new file mode 100644
index 0000000..6a5ea5c
--- /dev/null
+++ b/src/python/qubed/protobuf/qube_pb2.py
@@ -0,0 +1,45 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler.  DO NOT EDIT!
+# NO CHECKED-IN PROTOBUF GENCODE
+# source: qube.proto
+# Protobuf Python Version: 5.29.0
+"""Generated protocol buffer code."""
+
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import runtime_version as _runtime_version
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+
+_runtime_version.ValidateProtobufRuntimeVersion(
+    _runtime_version.Domain.PUBLIC, 5, 29, 0, "", "qube.proto"
+)
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+    b'\n\nqube.proto"4\n\x07NdArray\x12\r\n\x05shape\x18\x01 \x03(\x03\x12\r\n\x05\x64type\x18\x02 \x01(\t\x12\x0b\n\x03raw\x18\x03 \x01(\x0c"\x1c\n\x0bStringGroup\x12\r\n\x05items\x18\x01 \x03(\t"N\n\nValueGroup\x12\x19\n\x01s\x18\x01 \x01(\x0b\x32\x0c.StringGroupH\x00\x12\x1a\n\x06tensor\x18\x02 \x01(\x0b\x32\x08.NdArrayH\x00\x42\t\n\x07payload"6\n\rMetadataGroup\x12\x1a\n\x06tensor\x18\x01 \x01(\x0b\x32\x08.NdArrayH\x00\x42\t\n\x07payload"\xd1\x01\n\x04Qube\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x1b\n\x06values\x18\x02 \x01(\x0b\x32\x0b.ValueGroup\x12%\n\x08metadata\x18\x03 \x03(\x0b\x32\x13.Qube.MetadataEntry\x12\r\n\x05\x64type\x18\x04 \x01(\t\x12\x17\n\x08\x63hildren\x18\x05 \x03(\x0b\x32\x05.Qube\x12\x0f\n\x07is_root\x18\x06 \x01(\x08\x1a?\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x1d\n\x05value\x18\x02 \x01(\x0b\x32\x0e.MetadataGroup:\x02\x38\x01\x62\x06proto3'
+)
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "qube_pb2", _globals)
+if not _descriptor._USE_C_DESCRIPTORS:
+    DESCRIPTOR._loaded_options = None
+    _globals["_QUBE_METADATAENTRY"]._loaded_options = None
+    _globals["_QUBE_METADATAENTRY"]._serialized_options = b"8\001"
+    _globals["_NDARRAY"]._serialized_start = 14
+    _globals["_NDARRAY"]._serialized_end = 66
+    _globals["_STRINGGROUP"]._serialized_start = 68
+    _globals["_STRINGGROUP"]._serialized_end = 96
+    _globals["_VALUEGROUP"]._serialized_start = 98
+    _globals["_VALUEGROUP"]._serialized_end = 176
+    _globals["_METADATAGROUP"]._serialized_start = 178
+    _globals["_METADATAGROUP"]._serialized_end = 232
+    _globals["_QUBE"]._serialized_start = 235
+    _globals["_QUBE"]._serialized_end = 444
+    _globals["_QUBE_METADATAENTRY"]._serialized_start = 381
+    _globals["_QUBE_METADATAENTRY"]._serialized_end = 444
+# @@protoc_insertion_point(module_scope)
diff --git a/src/python/qubed/set_operations.py b/src/python/qubed/set_operations.py
index 84f0b6e..ecb2ad5 100644
--- a/src/python/qubed/set_operations.py
+++ b/src/python/qubed/set_operations.py
@@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Iterable
 import numpy as np
 from frozendict import frozendict
 
-from .node_types import NodeData
 from .value_types import QEnum, ValueGroup, WildcardGroup
 
 if TYPE_CHECKING:
@@ -27,7 +26,7 @@ class SetOperation(Enum):
 @dataclass(eq=True, frozen=True)
 class ValuesMetadata:
     values: ValueGroup
-    metadata: dict[str, np.ndarray]
+    indices: list[int] | slice
 
 
 def QEnum_intersection(
@@ -49,19 +48,17 @@ def QEnum_intersection(
 
     intersection_out = ValuesMetadata(
         values=QEnum(list(intersection.keys())),
-        metadata={
-            k: v[..., tuple(intersection.values())] for k, v in A.metadata.items()
-        },
+        indices=list(intersection.values()),
     )
 
     just_A_out = ValuesMetadata(
         values=QEnum(list(just_A.keys())),
-        metadata={k: v[..., tuple(just_A.values())] for k, v in A.metadata.items()},
+        indices=list(just_A.values()),
     )
 
     just_B_out = ValuesMetadata(
         values=QEnum(list(just_B.keys())),
-        metadata={k: v[..., tuple(just_B.values())] for k, v in B.metadata.items()},
+        indices=list(just_B.values()),
     )
 
     return just_A_out, intersection_out, just_B_out
@@ -76,61 +73,107 @@ def node_intersection(
 
     if isinstance(A.values, WildcardGroup) and isinstance(B.values, WildcardGroup):
         return (
-            ValuesMetadata(QEnum([]), {}),
-            ValuesMetadata(WildcardGroup(), {}),
-            ValuesMetadata(QEnum([]), {}),
+            ValuesMetadata(QEnum([]), []),
+            ValuesMetadata(WildcardGroup(), slice(None)),
+            ValuesMetadata(QEnum([]), []),
         )
 
     # If A is a wildcard matcher then the intersection is everything
     # just_A is still *
     # just_B is empty
     if isinstance(A.values, WildcardGroup):
-        return A, B, ValuesMetadata(QEnum([]), {})
+        return A, B, ValuesMetadata(QEnum([]), [])
 
     # The reverse if B is a wildcard
     if isinstance(B.values, WildcardGroup):
-        return ValuesMetadata(QEnum([]), {}), A, B
+        return ValuesMetadata(QEnum([]), []), A, B
 
     raise NotImplementedError(
         f"Fused set operations on values types {type(A.values)} and {type(B.values)} not yet implemented"
     )
 
 
-def operation(A: Qube, B: Qube, operation_type: SetOperation, node_type) -> Qube | None:
+def operation(
+    A: Qube, B: Qube, operation_type: SetOperation, node_type, depth=0
+) -> Qube | None:
     assert A.key == B.key, (
         "The two Qube root nodes must have the same key to perform set operations,"
         f"would usually be two root nodes. They have {A.key} and {B.key} respectively"
     )
+    node_key = A.key
+
+    assert A.is_root == B.is_root
+    is_root = A.is_root
 
     assert A.values == B.values, (
         f"The two Qube root nodes must have the same values to perform set operations {A.values = }, {B.values = }"
     )
+    node_values = A.values
 
     # Group the children of the two nodes by key
     nodes_by_key: defaultdict[str, tuple[list[Qube], list[Qube]]] = defaultdict(
         lambda: ([], [])
     )
-    for node in A.children:
-        nodes_by_key[node.key][0].append(node)
-    for node in B.children:
-        nodes_by_key[node.key][1].append(node)
-
     new_children: list[Qube] = []
 
+    # Sort out metadata into what can stay at this level and what must move down
+    stayput_metadata: dict[str, np.ndarray] = {}
+    pushdown_metadata_A: dict[str, np.ndarray] = {}
+    pushdown_metadata_B: dict[str, np.ndarray] = {}
+    for key in set(A.metadata.keys()) | set(B.metadata.keys()):
+        if key not in A.metadata:
+            raise ValueError(f"B has key {key} but A does not. {A = } {B = }")
+        if key not in B.metadata:
+            raise ValueError(f"A has key {key} but B does not. {A = } {B = }")
+
+        # Metadata values are arrays: a bare == is elementwise and ambiguous in an if
+        A_val = A.metadata[key]
+        B_val = B.metadata[key]
+        if np.array_equal(A_val, B_val):
+            # Identical on both sides, so the entry can stay on this node
+            stayput_metadata[key] = A_val
+        else:
+            # The two sides disagree, so each child must carry its own copy
+            pushdown_metadata_A[key] = A_val
+            pushdown_metadata_B[key] = B_val
+
+    # Add all the metadata that needs to be pushed down to the child nodes
+    # When pushing down the metadata we need to account for the fact it now affects more values
+    # So expand the metadata entries from shape (a, b, ..., c) to (a, b, ..., c, d)
+    # where d is the length of the node values
+    for node in A.children:
+        N = len(node.values)
+        # N becomes the size of the new trailing axis of each pushed-down array
+        meta = {
+            k: np.broadcast_to(v[..., np.newaxis], v.shape + (N,))
+            for k, v in pushdown_metadata_A.items()
+        }
+        node = node.replace(metadata=node.metadata | meta)
+        nodes_by_key[node.key][0].append(node)
+
+    for node in B.children:
+        N = len(node.values)
+        meta = {
+            k: np.broadcast_to(v[..., np.newaxis], v.shape + (N,))
+            for k, v in pushdown_metadata_B.items()
+        }
+        node = node.replace(metadata=node.metadata | meta)
+        nodes_by_key[node.key][1].append(node)
+
     # For every node group, perform the set operation
     for key, (A_nodes, B_nodes) in nodes_by_key.items():
-        output = list(_operation(key, A_nodes, B_nodes, operation_type, node_type))
+        output = list(
+            _operation(key, A_nodes, B_nodes, operation_type, node_type, depth + 1)
+        )
+        # print(f"{'  '*depth}_operation {operation_type.name} {A_nodes} {B_nodes} out = [{output}]")
         new_children.extend(output)
 
-    # print(f"operation {operation_type}: {A}, {B} {new_children = }")
-    # print(f"{A.children = }")
-    # print(f"{B.children = }")
-    # print(f"{new_children = }")
+    # print(f"{'  '*depth}operation {operation_type.name} [{A}] [{B}] new_children = [{new_children}]")
 
     # If there are now no children as a result of the operation, return nothing.
     if (A.children or B.children) and not new_children:
         if A.key == "root":
-            return A.replace(children=())
+            return node_type.make_root(children=())
         else:
             return None
 
@@ -140,20 +183,34 @@ def operation(A: Qube, B: Qube, operation_type: SetOperation, node_type) -> Qube
     new_children = list(compress_children(new_children))
 
     # The values and key are the same so we just replace the children
-    return A.replace(children=tuple(sorted(new_children)))
+    return node_type.make_node(
+        key=node_key,
+        values=node_values,
+        children=new_children,
+        metadata=stayput_metadata,
+        is_root=is_root,
+    )
+
+
+def get_indices(metadata: dict[str, np.ndarray], indices: list[int] | slice):
+    return {k: v[..., indices] for k, v in metadata.items()}
 
 
-# The root node is special so we need a helper method that we can recurse on
 def _operation(
-    key: str, A: list[Qube], B: list[Qube], operation_type: SetOperation, node_type
+    key: str,
+    A: list[Qube],
+    B: list[Qube],
+    operation_type: SetOperation,
+    node_type,
+    depth: int,
 ) -> Iterable[Qube]:
     keep_just_A, keep_intersection, keep_just_B = operation_type.value
 
-    # Iterate over all pairs (node_A, node_B)
     values = {}
     for node in A + B:
         values[node] = ValuesMetadata(node.values, node.metadata)
 
+    # Iterate over all pairs (node_A, node_B)
     for node_a in A:
         for node_b in B:
             # Compute A - B, A & B, B - A
@@ -171,17 +228,21 @@ def _operation(
                 if intersection.values:
                     new_node_a = node_a.replace(
                         values=intersection.values,
-                        metadata=intersection.metadata,
+                        metadata=get_indices(node_a.metadata, intersection.indices),
                     )
                     new_node_b = node_b.replace(
                         values=intersection.values,
-                        metadata=intersection.metadata,
+                        metadata=get_indices(node_b.metadata, intersection.indices),
                     )
-                    # print(f"{node_a = }")
-                    # print(f"{node_b = }")
-                    # print(f"{intersection.values =}")
+                    # print(f"{' '*depth}{node_a = }")
+                    # print(f"{' '*depth}{node_b = }")
+                    # print(f"{' '*depth}{intersection.values =}")
                     result = operation(
-                        new_node_a, new_node_b, operation_type, node_type
+                        new_node_a,
+                        new_node_b,
+                        operation_type,
+                        node_type,
+                        depth=depth + 1,
                     )
                     if result is not None:
                         yield result
@@ -190,20 +251,20 @@ def _operation(
     if keep_just_A:
         for node in A:
             if values[node].values:
-                yield node_type.make(
+                yield node_type.make_node(
                     key,
                     children=node.children,
                     values=values[node].values,
-                    metadata=values[node].metadata,
+                    metadata=get_indices(node.metadata, values[node].indices),
                 )
     if keep_just_B:
         for node in B:
             if values[node].values:
-                yield node_type.make(
+                yield node_type.make_node(
                     key,
                     children=node.children,
                     values=values[node].values,
-                    metadata=values[node].metadata,
+                    metadata=get_indices(node.metadata, values[node].indices),
                 )
 
 
@@ -230,7 +291,7 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
             key = child_list[0].key
 
             # Compress the children into a single node
-            assert all(isinstance(child.data.values, QEnum) for child in child_list), (
+            assert all(isinstance(child.values, QEnum) for child in child_list), (
                 "All children must have QEnum values"
             )
 
@@ -241,19 +302,19 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
 
             metadata: frozendict[str, np.ndarray] = frozendict(
                 {
-                    k: np.concatenate(metadata_group, axis=0)
+                    k: np.concatenate(metadata_group, axis=-1)
                     for k, metadata_group in metadata_groups.items()
                 }
             )
 
-            node_data = NodeData(
-                key=key,
-                metadata=metadata,
-                values=QEnum(set(v for child in child_list for v in child.data.values)),
-            )
             children = [cc for c in child_list for cc in c.children]
             compressed_children = compress_children(children)
-            new_child = node_type(data=node_data, children=compressed_children)
+            new_child = node_type.make_node(
+                key=key,
+                metadata=metadata,
+                values=QEnum(set(v for child in child_list for v in child.values)),
+                children=compressed_children,
+            )
         else:
             # If the group is size one just keep it
             new_child = child_list.pop()
diff --git a/src/python/qubed/tree_formatters.py b/src/python/qubed/tree_formatters.py
index 49ead5a..cabd7dd 100644
--- a/src/python/qubed/tree_formatters.py
+++ b/src/python/qubed/tree_formatters.py
@@ -2,9 +2,8 @@ from __future__ import annotations
 
 import random
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Iterable
+from typing import TYPE_CHECKING, Callable, Iterable
 
-import numpy as np
 
 if TYPE_CHECKING:
     from .Qube import Qube
@@ -71,27 +70,38 @@ def node_tree_to_string(node: Qube, prefix: str = "", depth=None) -> Iterable[st
 
 
 def summarize_node_html(
-    node: Qube, collapse=False, max_summary_length=50, **kwargs
+    node: Qube,
+    collapse=False,
+    max_summary_length=50,
+    info: Callable[[Qube], str] | None = None,
+    **kwargs,
 ) -> tuple[str, Qube]:
     """
     Extracts a summarized representation of the node while collapsing single-child paths.
     Returns the summary string and the last node in the chain that has multiple children.
     """
+    if info is None:
+
+        def info_func(node: Qube, /):
+            return (
+                # f"dtype: {node.dtype}\n"
+                f"metadata: {dict(node.metadata)}\n"
+            )
+    else:
+        info_func = info
+
     summaries = []
 
     while True:
         path = node.summary(**kwargs)
         summary = path
-        if "is_leaf" in node.metadata and node.metadata["is_leaf"]:
-            summary += " 🌿"
 
         if len(summary) > max_summary_length:
             summary = summary[:max_summary_length] + "..."
-        info = (
-            f"dtype: {node.dtype.__name__}\n"
-            f"metadata: {dict((k, np.shape(v)) for k, v in node.metadata.items())}\n"
-        )
-        summary = f'<span class="qubed-node" data-path="{path}" title="{info}">{summary}</span>'
+
+        info_string = info_func(node)
+
+        summary = f'<span class="qubed-node" data-path="{path}" title="{info_string}">{summary}</span>'
         summaries.append(summary)
         if not collapse:
             break
@@ -105,9 +115,14 @@ def summarize_node_html(
 
 
 def _node_tree_to_html(
-    node: Qube, prefix: str = "", depth=1, connector="", **kwargs
+    node: Qube,
+    prefix: str = "",
+    depth=1,
+    connector="",
+    info: Callable[[Qube], str] | None = None,
+    **kwargs,
 ) -> Iterable[str]:
-    summary, node = summarize_node_html(node, **kwargs)
+    summary, node = summarize_node_html(node, info=info, **kwargs)
 
     if len(node.children) == 0:
         yield f'<span class="qubed-level">{connector}{summary}</span>'
@@ -124,13 +139,20 @@ def _node_tree_to_html(
             prefix + extension,
             depth=depth - 1,
             connector=prefix + connector,
+            info=info,
             **kwargs,
         )
     yield "</details>"
 
 
 def node_tree_to_html(
-    node: Qube, depth=1, include_css=True, include_js=True, css_id=None, **kwargs
+    node: Qube,
+    depth=1,
+    include_css=True,
+    include_js=True,
+    css_id=None,
+    info: Callable[[Qube], str] | None = None,
+    **kwargs,
 ) -> str:
     if css_id is None:
         css_id = f"qubed-tree-{random.randint(0, 1000000)}"
@@ -215,5 +237,5 @@ def node_tree_to_html(
         nodes.forEach(n => n.addEventListener("click", nodeOnClick));
         </script>
         """.replace("CSS_ID", css_id)
-    nodes = "".join(_node_tree_to_html(node=node, depth=depth, **kwargs))
+    nodes = "".join(_node_tree_to_html(node=node, depth=depth, info=info, **kwargs))
     return f"{js if include_js else ''}{css if include_css else ''}<pre class='qubed-tree' id='{css_id}'>{nodes}</pre>"
diff --git a/src/python/qubed/value_types.py b/src/python/qubed/value_types.py
index ad38af6..e72f593 100644
--- a/src/python/qubed/value_types.py
+++ b/src/python/qubed/value_types.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import dataclasses
 from abc import ABC, abstractmethod
-from dataclasses import dataclass, replace
+from dataclasses import dataclass
 from datetime import date, datetime, timedelta
 from typing import (
     TYPE_CHECKING,
@@ -21,6 +21,11 @@ if TYPE_CHECKING:
 
 @dataclass(frozen=True)
 class ValueGroup(ABC):
+    @abstractmethod
+    def dtype(self) -> str:
+        "Provide a string representation of the datatype of these values."
+        pass
+
     @abstractmethod
     def summary(self) -> str:
         "Provide a string summary of the value group."
@@ -69,9 +74,13 @@ class QEnum(ValueGroup):
     """
 
     values: EnumValuesType
+    _dtype: str = "str"
 
     def __init__(self, obj):
         object.__setattr__(self, "values", tuple(sorted(obj)))
+        object.__setattr__(  # fill _dtype, the field that dtype() returns
+            self, "_dtype", type(self.values[0]).__name__ if self.values else "str"
+        )
 
     def __post_init__(self):
         assert isinstance(self.values, tuple)
@@ -88,6 +97,9 @@ class QEnum(ValueGroup):
     def __contains__(self, value: Any) -> bool:
         return value in self.values
 
+    def dtype(self) -> str:
+        return self._dtype
+
     @classmethod
     def from_strings(cls, values: Iterable[str]) -> Sequence[ValueGroup]:
         return [cls(tuple(values))]
@@ -114,7 +126,7 @@ class WildcardGroup(ValueGroup):
         return "*"
 
     def __len__(self):
-        return None
+        return 1
 
     def __iter__(self):
         return ["*"]
@@ -122,6 +134,9 @@ class WildcardGroup(ValueGroup):
     def __bool__(self):
         return True
 
+    def dtype(self) -> str:
+        return "*"
+
     @classmethod
     def from_strings(cls, values: Iterable[str]) -> Sequence[ValueGroup]:
         return [WildcardGroup()]
@@ -398,7 +413,7 @@ def convert_datatypes(q: "Qube", conversions: dict[str, ValueGroup]) -> "Qube":
             )
             for values_group in data_type.from_strings(q.values):
                 # print(values_group)
-                yield replace(q, data=replace(q.data, values=values_group))
+                yield q.replace(values=values_group)
         else:
             yield q
 
diff --git a/src/qube.proto b/src/qube.proto
new file mode 100644
index 0000000..bb2c685
--- /dev/null
+++ b/src/qube.proto
@@ -0,0 +1,32 @@
+syntax = "proto3";
+
+message NdArray {
+  repeated int64 shape = 1;
+  string  dtype  = 2;
+  bytes   raw    = 3;
+}
+
+message StringGroup {repeated string items = 1; }
+
+// Stores the values of a node, e.g. for class=1/2/3 this is the 1/2/3 part
+message ValueGroup {
+    oneof payload {
+    StringGroup s = 1;
+    NdArray tensor = 2;
+  }
+}
+
+message MetadataGroup {
+    oneof payload {
+    NdArray tensor = 1;
+  }
+}
+
+message Qube {
+  string key          = 1;
+  ValueGroup values   = 2;
+  map<string, MetadataGroup> metadata = 3;
+  string dtype        = 4;
+  repeated Qube children = 5;
+  bool is_root        = 6;
+}
diff --git a/stac_server/main.py b/stac_server/main.py
index 1c6680a..83eeaf9 100644
--- a/stac_server/main.py
+++ b/stac_server/main.py
@@ -176,7 +176,7 @@ def follow_query(request: dict[str, str | list[str]], qube: Qube):
     by_path = defaultdict(lambda: {"paths": set(), "values": set()})
 
     for request, node in s.leaf_nodes():
-        if not node.data.metadata["is_leaf"]:
+        if not node.metadata["is_leaf"]:
             by_path[node.key]["values"].update(node.values.values)
             by_path[node.key]["paths"].add(frozendict(request))
 
diff --git a/tests/test_basic_operations.py b/tests/test_basic_operations.py
index 0f4fb81..fb9d8e0 100644
--- a/tests/test_basic_operations.py
+++ b/tests/test_basic_operations.py
@@ -1,36 +1,40 @@
 from qubed import Qube
 
-d = {
-    "class=od": {
-        "expver=0001": {"param=1": {}, "param=2": {}},
-        "expver=0002": {"param=1": {}, "param=2": {}},
-    },
-    "class=rd": {
-        "expver=0001": {"param=1": {}, "param=2": {}, "param=3": {}},
-        "expver=0002": {"param=1": {}, "param=2": {}},
-    },
-}
-q = Qube.from_dict(d)
-
-
-def test_eq():
-    r = Qube.from_dict(d)
-    assert q == r
+q = Qube.from_tree("""
+root
+├── class=od
+│   ├── expver=0001
+│   │   ├── param=1
+│   │   └── param=2
+│   └── expver=0002
+│       ├── param=1
+│       └── param=2
+└── class=rd
+    ├── expver=0001
+    │   ├── param=1
+    │   ├── param=2
+    │   └── param=3
+    └── expver=0002
+        ├── param=1
+        └── param=2
+""")
 
 
 def test_getitem():
-    assert q["class", "od"] == Qube.from_dict(
-        {
-            "expver=0001": {"param=1": {}, "param=2": {}},
-            "expver=0002": {"param=1": {}, "param=2": {}},
-        }
-    )
-    assert q["class", "od"]["expver", "0001"] == Qube.from_dict(
-        {
-            "param=1": {},
-            "param=2": {},
-        }
-    )
+    assert q["class", "od"] == Qube.from_tree("""
+root
+├── expver=0001
+│   ├── param=1
+│   └── param=2
+└── expver=0002
+    ├── param=1
+    └── param=2
+""")
+
+    assert q["class", "od"]["expver", "0001"] == Qube.from_tree("""
+root
+├── param=1
+└── param=2""")
 
 
 def test_n_leaves():
diff --git a/tests/test_conversions.py b/tests/test_conversions.py
index c37117a..16ed654 100644
--- a/tests/test_conversions.py
+++ b/tests/test_conversions.py
@@ -2,7 +2,7 @@ from qubed import Qube
 
 
 def test_json_round_trip():
-    u = Qube.from_dict(
+    from_dict = Qube.from_dict(
         {
             "class=d1": {
                 "dataset=climate-dt/weather-dt": {
@@ -14,5 +14,54 @@ def test_json_round_trip():
             }
         }
     )
-    json = u.to_json()
-    assert Qube.from_json(json) == u
+
+    from_tree = Qube.from_tree("""
+    root, class=d1
+    ├── dataset=another-value, generation=1/2/3
+    └── dataset=climate-dt/weather-dt, generation=1/2/3/4
+    """)
+
+    from_json = Qube.from_json(
+        {
+            "key": "root",
+            "values": ["root"],
+            "metadata": {},
+            "children": [
+                {
+                    "key": "class",
+                    "values": ["d1"],
+                    "metadata": {},
+                    "children": [
+                        {
+                            "key": "dataset",
+                            "values": ["another-value"],
+                            "metadata": {},
+                            "children": [
+                                {
+                                    "key": "generation",
+                                    "values": ["1", "2", "3"],
+                                    "metadata": {},
+                                    "children": [],
+                                }
+                            ],
+                        },
+                        {
+                            "key": "dataset",
+                            "values": ["climate-dt", "weather-dt"],
+                            "metadata": {},
+                            "children": [
+                                {
+                                    "key": "generation",
+                                    "values": ["1", "2", "3", "4"],
+                                    "metadata": {},
+                                    "children": [],
+                                }
+                            ],
+                        },
+                    ],
+                }
+            ],
+        }
+    )
+    assert from_tree == from_json
+    assert from_tree == from_dict
diff --git a/tests/test_formatters.py b/tests/test_formatters.py
index 7278927..571503e 100644
--- a/tests/test_formatters.py
+++ b/tests/test_formatters.py
@@ -20,14 +20,6 @@ root
     └── expver=0002, param=1/2
 """.strip()
 
-as_html = """
-<details open><summary class="qubed-level"><span class="qubed-node" data-path="root" title="dtype: str\nmetadata: {}\n">root</span></summary><span class="qubed-level">├── <span class="qubed-node" data-path="class=od" title="dtype: str\nmetadata: {}\n">class=od</span>, <span class="qubed-node" data-path="expver=0001/0002" title="dtype: str\nmetadata: {}\n">expver=0001/0002</span>, <span class="qubed-node" data-path="param=1/2" title="dtype: str\nmetadata: {}\n">param=1/2</span></span><details open><summary class="qubed-level">└── <span class="qubed-node" data-path="class=rd" title="dtype: str\nmetadata: {}\n">class=rd</span></summary><span class="qubed-level">    ├── <span class="qubed-node" data-path="expver=0001" title="dtype: str\nmetadata: {}\n">expver=0001</span>, <span class="qubed-node" data-path="param=1/2/3" title="dtype: str\nmetadata: {}\n">param=1/2/3</span></span><span class="qubed-level">    └── <span class="qubed-node" data-path="expver=0002" title="dtype: str\nmetadata: {}\n">expver=0002</span>, <span class="qubed-node" data-path="param=1/2" title="dtype: str\nmetadata: {}\n">param=1/2</span></span></details></details>
-""".strip()
-
 
 def test_string():
     assert str(q).strip() == as_string
-
-
-def test_html():
-    assert as_html in q._repr_html_()
diff --git a/tests/test_metadata.py b/tests/test_metadata.py
index 6d9e416..ceaedf9 100644
--- a/tests/test_metadata.py
+++ b/tests/test_metadata.py
@@ -1,45 +1,44 @@
 from frozendict import frozendict
-from qubed import Qube
 
 
 def make_set(entries):
     return set((frozendict(a), frozendict(b)) for a, b in entries)
 
 
-def test_simple_union():
-    q = Qube.from_nodes(
-        {
-            "class": dict(values=["od", "rd"]),
-            "expver": dict(values=[1, 2]),
-            "stream": dict(
-                values=["a", "b", "c"], metadata=dict(number=list(range(12)))
-            ),
-        }
-    )
+# def test_simple_union():
+#     q = Qube.from_nodes(
+#         {
+#             "class": dict(values=["od", "rd"]),
+#             "expver": dict(values=[1, 2]),
+#             "stream": dict(
+#                 values=["a", "b", "c"], metadata=dict(number=list(range(12)))
+#             ),
+#         }
+#     )
 
-    r = Qube.from_nodes(
-        {
-            "class": dict(values=["xd"]),
-            "expver": dict(values=[1, 2]),
-            "stream": dict(
-                values=["a", "b", "c"], metadata=dict(number=list(range(12, 18)))
-            ),
-        }
-    )
+#     r = Qube.from_nodes(
+#         {
+#             "class": dict(values=["xd"]),
+#             "expver": dict(values=[1, 2]),
+#             "stream": dict(
+#                 values=["a", "b", "c"], metadata=dict(number=list(range(12, 18)))
+#             ),
+#         }
+#     )
 
-    expected_union = Qube.from_nodes(
-        {
-            "class": dict(values=["od", "rd", "xd"]),
-            "expver": dict(values=[1, 2]),
-            "stream": dict(
-                values=["a", "b", "c"], metadata=dict(number=list(range(18)))
-            ),
-        }
-    )
+#     expected_union = Qube.from_nodes(
+#         {
+#             "class": dict(values=["od", "rd", "xd"]),
+#             "expver": dict(values=[1, 2]),
+#             "stream": dict(
+#                 values=["a", "b", "c"], metadata=dict(number=list(range(18)))
+#             ),
+#         }
+#     )
 
-    union = q | r
+#     union = q | r
 
-    assert union == expected_union
-    assert make_set(expected_union.leaves_with_metadata()) == make_set(
-        union.leaves_with_metadata()
-    )
+#     assert union == expected_union
+#     assert make_set(expected_union.leaves_with_metadata()) == make_set(
+#         union.leaves_with_metadata()
+#     )
diff --git a/tests/test_protobuf.py b/tests/test_protobuf.py
new file mode 100644
index 0000000..34d0442
--- /dev/null
+++ b/tests/test_protobuf.py
@@ -0,0 +1,12 @@
+from qubed import Qube
+
+
+def test_protobuf_simple():
+    q = Qube.from_tree("""
+    root, class=d1
+    ├── dataset=another-value, generation=1/2/3
+    └── dataset=climate-dt/weather-dt, generation=1/2/3/4
+    """)
+    wire = q.to_protobuf()
+    round_trip = Qube.from_protobuf(wire)
+    assert round_trip == q  # encode -> decode must reproduce the original Qube