Massive rewrite

Tom 2025-05-14 10:14:02 +01:00
parent ed4a9055fa
commit 35bb8f0edd
17 changed files with 623 additions and 243 deletions

View File

@@ -8,17 +8,17 @@ import functools
 import json
 from collections import defaultdict
 from collections.abc import Callable
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import cached_property
 from pathlib import Path
-from typing import Any, Iterable, Iterator, Literal, Sequence
+from typing import Any, Iterable, Iterator, Literal, Self, Sequence

 import numpy as np
 from frozendict import frozendict
 from . import set_operations
 from .metadata import from_nodes
-from .node_types import NodeData, RootNodeData
+from .protobuf.adapters import proto_to_qube, qube_to_proto
 from .tree_formatters import (
     HTML,
     node_tree_to_html,
@@ -32,58 +32,90 @@ from .value_types import (
 )

-@dataclass(frozen=False, eq=True, order=True, unsafe_hash=True)
-class Qube:
-    data: NodeData
-    children: tuple[Qube, ...]
-
-    @property
-    def key(self) -> str:
-        return self.data.key
-
-    @property
-    def values(self) -> ValueGroup:
-        return self.data.values
-
-    @property
-    def metadata(self):
-        return self.data.metadata
-
-    @property
-    def dtype(self):
-        return self.data.dtype
-
-    def replace(self, **kwargs) -> Qube:
-        data_keys = {
-            k: v
-            for k, v in kwargs.items()
-            if k in ["key", "values", "metadata", "dtype"]
-        }
-        node_keys = {k: v for k, v in kwargs.items() if k == "children"}
-        if not data_keys and not node_keys:
-            return self
-        if not data_keys:
-            return dataclasses.replace(self, **node_keys)
-        return dataclasses.replace(
-            self, data=dataclasses.replace(self.data, **data_keys), **node_keys
-        )
-
-    def summary(self) -> str:
-        return self.data.summary()
-
-    @classmethod
-    def make(cls, key: str, values: ValueGroup, children, **kwargs) -> Qube:
-        return cls(
-            data=NodeData(
-                key, values, metadata=frozendict(kwargs.get("metadata", frozendict()))
-            ),
-            children=tuple(sorted(children, key=lambda n: ((n.key, n.values.min())))),
-        )
-
-    @classmethod
-    def root_node(cls, children: Iterable[Qube]) -> Qube:
-        return cls.make("root", QEnum(("root",)), children)
+@dataclass
+class AxisInfo:
+    key: str
+    type: Any
+    depths: set[int]
+    values: set
+
+    def combine(self, other: Self):
+        self.key = other.key
+        self.type = other.type
+        self.depths.update(other.depths)
+        self.values.update(other.values)
+        # print(f"combining {self} and {other} getting {result}")
+
+    def to_json(self):
+        return {
+            "key": self.key,
+            "type": self.type.__name__,
+            "values": list(self.values),
+            "depths": list(self.depths),
+        }
+
+
+@dataclass(frozen=True, eq=True, order=True, unsafe_hash=True)
+class QubeNamedRoot:
+    "Helper class to print a custom root name"
+
+    key: str
+    dtype: str = "str"
+    children: tuple[Qube, ...] = ()
+
+    def summary(self) -> str:
+        return self.key
+
+
+@dataclass(frozen=True, eq=True, order=True, unsafe_hash=True)
+class Qube:
+    key: str
+    values: ValueGroup
+    metadata: frozendict[str, np.ndarray] = field(
+        default_factory=lambda: frozendict({}), compare=False
+    )
+    children: tuple[Qube, ...] = ()
+    is_root: bool = False
+
+    def replace(self, **kwargs) -> Qube:
+        return dataclasses.replace(self, **kwargs)
+
+    def summary(self) -> str:
+        if self.is_root:
+            return self.key
+        return f"{self.key}={self.values.summary()}" if self.key != "root" else "root"
+
+    @classmethod
+    def make_node(
+        cls,
+        key: str,
+        values: Iterable | QEnum | WildcardGroup,
+        children: Iterable[Qube],
+        metadata: dict[str, np.ndarray] = {},
+        is_root: bool = False,
+    ) -> Qube:
+        if isinstance(values, ValueGroup):
+            values = values
+        else:
+            values = QEnum(values)
+
+        return cls(
+            key,
+            values=values,
+            children=tuple(sorted(children, key=lambda n: ((n.key, n.values.min())))),
+            metadata=frozendict(metadata),
+            is_root=is_root,
+        )
+
+    @classmethod
+    def make_root(cls, children: Iterable[Qube], metadata={}) -> Qube:
+        return cls.make_node(
+            "root",
+            values=QEnum(("root",)),
+            children=children,
+            metadata=metadata,
+            is_root=True,
+        )

     @classmethod
     def load(cls, path: str | Path) -> Qube:
@@ -104,18 +136,19 @@ class Qube:
             else:
                 values_group = QEnum([values])

-            children = [cls.make(key, values_group, children)]
+            children = [cls.make_node(key, values_group, children)]

-        return cls.root_node(children)
+        return cls.make_root(children)

     @classmethod
     def from_json(cls, json: dict) -> Qube:
-        def from_json(json: dict) -> Qube:
-            return Qube.make(
+        def from_json(json: dict, depth=0) -> Qube:
+            return Qube.make_node(
                 key=json["key"],
                 values=values_from_json(json["values"]),
                 metadata=frozendict(json["metadata"]) if "metadata" in json else {},
-                children=(from_json(c) for c in json["children"]),
+                children=(from_json(c, depth + 1) for c in json["children"]),
+                is_root=(depth == 0),
             )

         return from_json(json)
@@ -146,13 +179,13 @@ class Qube:
             else:
                 values = QEnum(values)

-            yield Qube.make(
+            yield Qube.make_node(
                 key=key,
                 values=values,
                 children=from_dict(children),
             )

-        return Qube.root_node(list(from_dict(d)))
+        return Qube.make_root(list(from_dict(d)))

     def to_dict(self) -> dict:
         def to_dict(q: Qube) -> tuple[str, dict]:
@@ -161,6 +194,13 @@ class Qube:
         return to_dict(self)[1]

+    @classmethod
+    def from_protobuf(cls, msg: bytes) -> Qube:
+        return proto_to_qube(cls, msg)
+
+    def to_protobuf(self) -> bytes:
+        return qube_to_proto(self)
+
     @classmethod
     def from_tree(cls, tree_str):
         lines = tree_str.splitlines()
@@ -214,17 +254,12 @@ class Qube:
     @classmethod
     def empty(cls) -> Qube:
-        return Qube.root_node([])
+        return Qube.make_root([])

     def __str_helper__(self, depth=None, name=None) -> str:
-        node = (
-            dataclasses.replace(
-                self,
-                data=RootNodeData(key=name, values=self.values, metadata=self.metadata),
-            )
-            if name is not None
-            else self
-        )
+        node = self
+        if name is not None:
+            node = node.replace(key=name)
         out = "".join(node_tree_to_string(node=node, depth=depth))
         if out[-1] == "\n":
             out = out[:-1]
@@ -239,16 +274,19 @@ class Qube:
     def print(self, depth=None, name: str | None = None):
         print(self.__str_helper__(depth=depth, name=name))

-    def html(self, depth=2, collapse=True, name: str | None = None) -> HTML:
-        node = (
-            dataclasses.replace(
-                self,
-                data=RootNodeData(key=name, values=self.values, metadata=self.metadata),
-            )
-            if name is not None
-            else self
-        )
-        return HTML(node_tree_to_html(node=node, depth=depth, collapse=collapse))
+    def html(
+        self,
+        depth=2,
+        collapse=True,
+        name: str | None = None,
+        info: Callable[[Qube], str] | None = None,
+    ) -> HTML:
+        node = self
+        if name is not None:
+            node = node.replace(key=name)
+        return HTML(
+            node_tree_to_html(node=node, depth=depth, collapse=collapse, info=info)
+        )

     def _repr_html_(self) -> str:
         return node_tree_to_html(self, depth=2, collapse=True)
@@ -257,7 +295,7 @@ class Qube:
     def __rtruediv__(self, other: str) -> Qube:
         key, values = other.split("=")
         values_enum = QEnum((values.split("/")))
-        return Qube.root_node([Qube.make(key, values_enum, self.children)])
+        return Qube.make_root([Qube.make_node(key, values_enum, self.children)])

     def __or__(self, other: Qube) -> Qube:
         return set_operations.operation(
@@ -358,16 +396,16 @@ class Qube:
                 raise KeyError(
                     f"Key '{key}' not found in children of '{current.key}', available keys are {[c.key for c in current.children]}"
                 )
-            return Qube.root_node(current.children)
+            return Qube.make_root(current.children)

         elif isinstance(args, tuple) and len(args) == 2:
             key, value = args
             for c in self.children:
                 if c.key == key and value in c.values:
-                    return Qube.root_node(c.children)
+                    return Qube.make_root(c.children)
-            raise KeyError(f"Key {key} not found in children of {self.key}")
+            raise KeyError(f"Key '{key}' not found in children of {self.key}")
         else:
-            raise ValueError("Unknown key type")
+            raise ValueError(f"Unknown key type {args}")

     @cached_property
     def n_leaves(self) -> int:
@@ -410,7 +448,7 @@ class Qube:
             for c in node.children:
                 if c.key in _keys:
                     grandchildren = tuple(sorted(remove_key(cc) for cc in c.children))
-                    grandchildren = remove_key(Qube.root_node(grandchildren)).children
+                    grandchildren = remove_key(Qube.make_root(grandchildren)).children
                     children.extend(grandchildren)
                 else:
                     children.append(remove_key(c))
@@ -424,7 +462,7 @@ class Qube:
             if node.key in converters:
                 converter = converters[node.key]
                 values = [converter(v) for v in node.values]
-                new_node = node.replace(values=QEnum(values), dtype=type(values[0]))
+                new_node = node.replace(values=QEnum(values))
                 return new_node
             return node
@@ -516,7 +554,8 @@ class Qube:
             return node.replace(
                 children=new_children,
-                metadata=dict(self.metadata) | {"is_leaf": not bool(new_children)},
+                metadata=dict(self.metadata)
+                | ({"is_leaf": not bool(new_children)} if mode == "next_level" else {}),
             )

         return self.replace(
@@ -544,6 +583,26 @@ class Qube:
             axes[self.key].update(self.values)
         return dict(axes)

+    def axes_info(self, depth=0) -> dict[str, AxisInfo]:
+        axes = defaultdict(
+            lambda: AxisInfo(key="", type=str, depths=set(), values=set())
+        )
+        for c in self.children:
+            for k, info in c.axes_info(depth=depth + 1).items():
+                axes[k].combine(info)
+
+        if self.key != "root":
+            axes[self.key].combine(
+                AxisInfo(
+                    key=self.key,
+                    type=type(next(iter(self.values))),
+                    depths={depth},
+                    values=set(self.values),
+                )
+            )
+
+        return dict(axes)
+
     @cached_property
     def structural_hash(self) -> int:
         """
@@ -570,7 +629,7 @@ class Qube:
         """

         def union(a: Qube, b: Qube) -> Qube:
-            b = type(self).root_node(children=(b,))
+            b = type(self).make_root(children=(b,))
             out = set_operations.operation(
                 a, b, set_operations.SetOperation.UNION, type(self)
             )
@@ -583,3 +642,20 @@ class Qube:
             )

         return self.replace(children=tuple(sorted(new_children)))
+
+    def add_metadata(self, **kwargs: dict[str, Any]):
+        metadata = {
+            k: np.array(
+                [
+                    v,
+                ]
+            )
+            for k, v in kwargs.items()
+        }
+        return self.replace(metadata=metadata)
+
+    def strip_metadata(self) -> Qube:
+        def strip(node):
+            return node.replace(metadata=frozendict({}))
+
+        return self.transform(strip)
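
A minimal usage sketch (not part of the commit) of the flattened Qube API above. It relies only on names visible in this diff — make_node, make_root, print and axes_info — and on make_node wrapping plain iterables in QEnum:

from qubed import Qube

leaf = Qube.make_node(key="param", values=["1", "2"], children=())
branch = Qube.make_node(key="expver", values=["0001"], children=[leaf])
q = Qube.make_root(children=[branch])

q.print()             # renders the tree via node_tree_to_string
print(q.axes_info())  # {'expver': AxisInfo(...), 'param': AxisInfo(...)} with types, depths and values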

View File

@@ -1,3 +1,4 @@
+from . import protobuf
 from .Qube import Qube

-__all__ = ["Qube"]
+__all__ = ["Qube", "protobuf"]

View File

@@ -18,7 +18,7 @@ def make_node(
     children: tuple[Qube, ...],
     metadata: dict[str, np.ndarray] | None = None,
 ):
-    return cls.make(
+    return cls.make_node(
         key=key,
         values=QEnum(values),
         metadata={k: np.array(v).reshape(shape) for k, v in metadata.items()}
@@ -39,5 +39,5 @@ def from_nodes(cls, nodes, add_root=True):
         root = make_node(cls, shape=shape, children=(root,), key=key, **info)

     if add_root:
-        return cls.root_node(children=(root,))
+        return cls.make_root(children=(root,))
     return root

View File

@@ -1,27 +0,0 @@
from dataclasses import dataclass, field

import numpy as np
from frozendict import frozendict

from .value_types import ValueGroup


@dataclass(frozen=False, eq=True, order=True, unsafe_hash=True)
class NodeData:
    key: str
    values: ValueGroup
    metadata: frozendict[str, np.ndarray] = field(
        default_factory=lambda: frozendict({}), compare=False
    )
    dtype: type = str

    def summary(self) -> str:
        return f"{self.key}={self.values.summary()}" if self.key != "root" else "root"


@dataclass(frozen=False, eq=True, order=True)
class RootNodeData(NodeData):
    "Helper class to print a custom root name"

    def summary(self) -> str:
        return self.key

View File

View File

@@ -0,0 +1,99 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
from frozendict import frozendict

from ..value_types import QEnum
from . import qube_pb2

if TYPE_CHECKING:
    from ..Qube import Qube


def _ndarray_to_proto(arr: np.ndarray) -> qube_pb2.NdArray:
    """np.ndarray → NdArray message"""
    return qube_pb2.NdArray(
        shape=list(arr.shape),
        dtype=str(arr.dtype),
        raw=arr.tobytes(order="C"),
    )


def _ndarray_from_proto(msg: qube_pb2.NdArray) -> np.ndarray:
    """NdArray message → np.ndarray (immutable view)"""
    return np.frombuffer(msg.raw, dtype=msg.dtype).reshape(tuple(msg.shape))


def _py_to_valuegroup(value: list[str] | np.ndarray) -> qube_pb2.ValueGroup:
    """Accept str-sequence *or* ndarray and return ValueGroup."""
    vg = qube_pb2.ValueGroup()
    if isinstance(value, np.ndarray):
        vg.tensor.CopyFrom(_ndarray_to_proto(value))
    else:
        vg.s.items.extend(value)
    return vg


def _valuegroup_to_py(vg: qube_pb2.ValueGroup) -> list[str] | np.ndarray:
    """ValueGroup → list[str] *or* ndarray"""
    arm = vg.WhichOneof("payload")
    if arm == "tensor":
        return _ndarray_from_proto(vg.tensor)
    return QEnum(vg.s.items)


def _py_to_metadatagroup(value: np.ndarray) -> qube_pb2.MetadataGroup:
    """Accept an ndarray (or scalar, which is wrapped) and return a MetadataGroup."""
    vg = qube_pb2.MetadataGroup()
    if not isinstance(value, np.ndarray):
        value = np.array([value])
    vg.tensor.CopyFrom(_ndarray_to_proto(value))
    return vg


def _metadatagroup_to_py(vg: qube_pb2.MetadataGroup) -> np.ndarray:
    """MetadataGroup → ndarray"""
    arm = vg.WhichOneof("payload")
    if arm == "tensor":
        return _ndarray_from_proto(vg.tensor)
    raise ValueError(f"Unknown arm {arm}")


def _qube_to_proto(q: Qube) -> qube_pb2.Qube:
    """Frozen Qube dataclass → protobuf Qube message (new object)."""
    return qube_pb2.Qube(
        key=q.key,
        values=_py_to_valuegroup(q.values),
        metadata={k: _py_to_metadatagroup(v) for k, v in q.metadata.items()},
        children=[_qube_to_proto(c) for c in q.children],
        is_root=q.is_root,
    )


def qube_to_proto(q: Qube) -> bytes:
    return _qube_to_proto(q).SerializeToString()


def _proto_to_qube(cls: type, msg: qube_pb2.Qube) -> Qube:
    """protobuf Qube message → frozen Qube dataclass (new object)."""
    return cls(
        key=msg.key,
        values=_valuegroup_to_py(msg.values),
        metadata=frozendict(
            {k: _metadatagroup_to_py(v) for k, v in msg.metadata.items()}
        ),
        children=tuple(_proto_to_qube(cls, c) for c in msg.children),
        is_root=msg.is_root,
    )


def proto_to_qube(cls: type, wire: bytes) -> Qube:
    msg = qube_pb2.Qube()
    msg.ParseFromString(wire)
    return _proto_to_qube(cls, msg)
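
A hedged sketch of the round trip these adapters enable, via the Qube.to_protobuf()/Qube.from_protobuf() wrappers added in Qube.py; the from_dict input mirrors the tests below:

from qubed import Qube

q = Qube.from_dict({"class=d1": {"dataset=climate-dt/weather-dt": {"generation=1/2/3/4": {}}}})

wire: bytes = q.to_protobuf()        # qube_to_proto(q).SerializeToString()
restored = Qube.from_protobuf(wire)  # proto_to_qube(Qube, wire)
assert restored == q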

View File

@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: qube.proto
# Protobuf Python Version: 5.29.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC, 5, 29, 0, "", "qube.proto"
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\nqube.proto"4\n\x07NdArray\x12\r\n\x05shape\x18\x01 \x03(\x03\x12\r\n\x05\x64type\x18\x02 \x01(\t\x12\x0b\n\x03raw\x18\x03 \x01(\x0c"\x1c\n\x0bStringGroup\x12\r\n\x05items\x18\x01 \x03(\t"N\n\nValueGroup\x12\x19\n\x01s\x18\x01 \x01(\x0b\x32\x0c.StringGroupH\x00\x12\x1a\n\x06tensor\x18\x02 \x01(\x0b\x32\x08.NdArrayH\x00\x42\t\n\x07payload"6\n\rMetadataGroup\x12\x1a\n\x06tensor\x18\x01 \x01(\x0b\x32\x08.NdArrayH\x00\x42\t\n\x07payload"\xd1\x01\n\x04Qube\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x1b\n\x06values\x18\x02 \x01(\x0b\x32\x0b.ValueGroup\x12%\n\x08metadata\x18\x03 \x03(\x0b\x32\x13.Qube.MetadataEntry\x12\r\n\x05\x64type\x18\x04 \x01(\t\x12\x17\n\x08\x63hildren\x18\x05 \x03(\x0b\x32\x05.Qube\x12\x0f\n\x07is_root\x18\x06 \x01(\x08\x1a?\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x1d\n\x05value\x18\x02 \x01(\x0b\x32\x0e.MetadataGroup:\x02\x38\x01\x62\x06proto3'
)
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "qube_pb2", _globals)
if not _descriptor._USE_C_DESCRIPTORS:
    DESCRIPTOR._loaded_options = None
    _globals["_QUBE_METADATAENTRY"]._loaded_options = None
    _globals["_QUBE_METADATAENTRY"]._serialized_options = b"8\001"
    _globals["_NDARRAY"]._serialized_start = 14
    _globals["_NDARRAY"]._serialized_end = 66
    _globals["_STRINGGROUP"]._serialized_start = 68
    _globals["_STRINGGROUP"]._serialized_end = 96
    _globals["_VALUEGROUP"]._serialized_start = 98
    _globals["_VALUEGROUP"]._serialized_end = 176
    _globals["_METADATAGROUP"]._serialized_start = 178
    _globals["_METADATAGROUP"]._serialized_end = 232
    _globals["_QUBE"]._serialized_start = 235
    _globals["_QUBE"]._serialized_end = 444
    _globals["_QUBE_METADATAENTRY"]._serialized_start = 381
    _globals["_QUBE_METADATAENTRY"]._serialized_end = 444
# @@protoc_insertion_point(module_scope)

View File

@@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Iterable
 import numpy as np
 from frozendict import frozendict

-from .node_types import NodeData
 from .value_types import QEnum, ValueGroup, WildcardGroup

 if TYPE_CHECKING:
@@ -27,7 +26,7 @@ class SetOperation(Enum):
 @dataclass(eq=True, frozen=True)
 class ValuesMetadata:
     values: ValueGroup
-    metadata: dict[str, np.ndarray]
+    indices: list[int] | slice


 def QEnum_intersection(
@@ -49,19 +48,17 @@ def QEnum_intersection(
     intersection_out = ValuesMetadata(
         values=QEnum(list(intersection.keys())),
-        metadata={
-            k: v[..., tuple(intersection.values())] for k, v in A.metadata.items()
-        },
+        indices=list(intersection.values()),
     )

     just_A_out = ValuesMetadata(
         values=QEnum(list(just_A.keys())),
-        metadata={k: v[..., tuple(just_A.values())] for k, v in A.metadata.items()},
+        indices=list(just_A.values()),
     )

     just_B_out = ValuesMetadata(
         values=QEnum(list(just_B.keys())),
-        metadata={k: v[..., tuple(just_B.values())] for k, v in B.metadata.items()},
+        indices=list(just_B.values()),
     )

     return just_A_out, intersection_out, just_B_out
@@ -76,61 +73,107 @@ def node_intersection(
     if isinstance(A.values, WildcardGroup) and isinstance(B.values, WildcardGroup):
         return (
-            ValuesMetadata(QEnum([]), {}),
-            ValuesMetadata(WildcardGroup(), {}),
-            ValuesMetadata(QEnum([]), {}),
+            ValuesMetadata(QEnum([]), []),
+            ValuesMetadata(WildcardGroup(), slice(None)),
+            ValuesMetadata(QEnum([]), []),
         )

     # If A is a wildcard matcher then the intersection is everything
     # just_A is still *
     # just_B is empty
     if isinstance(A.values, WildcardGroup):
-        return A, B, ValuesMetadata(QEnum([]), {})
+        return A, B, ValuesMetadata(QEnum([]), [])

     # The reverse if B is a wildcard
     if isinstance(B.values, WildcardGroup):
-        return ValuesMetadata(QEnum([]), {}), A, B
+        return ValuesMetadata(QEnum([]), []), A, B

     raise NotImplementedError(
         f"Fused set operations on values types {type(A.values)} and {type(B.values)} not yet implemented"
     )

-def operation(A: Qube, B: Qube, operation_type: SetOperation, node_type) -> Qube | None:
+def operation(
+    A: Qube, B: Qube, operation_type: SetOperation, node_type, depth=0
+) -> Qube | None:
     assert A.key == B.key, (
         "The two Qube root nodes must have the same key to perform set operations,"
         f"would usually be two root nodes. They have {A.key} and {B.key} respectively"
     )
+    node_key = A.key
+    assert A.is_root == B.is_root
+    is_root = A.is_root

     assert A.values == B.values, (
         f"The two Qube root nodes must have the same values to perform set operations {A.values = }, {B.values = }"
     )
+    node_values = A.values

     # Group the children of the two nodes by key
     nodes_by_key: defaultdict[str, tuple[list[Qube], list[Qube]]] = defaultdict(
         lambda: ([], [])
     )
-    for node in A.children:
-        nodes_by_key[node.key][0].append(node)
-    for node in B.children:
-        nodes_by_key[node.key][1].append(node)

     new_children: list[Qube] = []

+    # Sort out metadata into what can stay at this level and what must move down
+    stayput_metadata: dict[str, np.ndarray] = {}
+    pushdown_metadata_A: dict[str, np.ndarray] = {}
+    pushdown_metadata_B: dict[str, np.ndarray] = {}
+    for key in set(A.metadata.keys()) | set(B.metadata.keys()):
+        if key not in A.metadata:
+            raise ValueError(f"B has key {key} but A does not. {A = } {B = }")
+        if key not in B.metadata:
+            raise ValueError(f"A has key {key} but B does not. {A = } {B = }")
+
+        print(f"{key = } {A.metadata[key] = } {B.metadata[key]}")
+        A_val = A.metadata[key]
+        B_val = B.metadata[key]
+        if A_val == B_val:
+            print(f"{' ' * depth}Keeping metadata key '{key}' at this level")
+            stayput_metadata[key] = A.metadata[key]
+        else:
+            print(f"{' ' * depth}Pushing down metadata key '{key}' {A_val} {B_val}")
+            pushdown_metadata_A[key] = A_val
+            pushdown_metadata_B[key] = B_val
+
+    # Add all the metadata that needs to be pushed down to the child nodes
+    # When pushing down the metadata we need to account for the fact it now affects more values
+    # So expand the metadata entries from shape (a, b, ..., c) to (a, b, ..., c, d)
+    # where d is the length of the node values
+    for node in A.children:
+        N = len(node.values)
+        print(N)
+        meta = {
+            k: np.broadcast_to(v[..., np.newaxis], v.shape + (N,))
+            for k, v in pushdown_metadata_A.items()
+        }
+        node = node.replace(metadata=node.metadata | meta)
+        nodes_by_key[node.key][0].append(node)
+
+    for node in B.children:
+        N = len(node.values)
+        meta = {
+            k: np.broadcast_to(v[..., np.newaxis], v.shape + (N,))
+            for k, v in pushdown_metadata_B.items()
+        }
+        node = node.replace(metadata=node.metadata | meta)
+        nodes_by_key[node.key][1].append(node)
+
     # For every node group, perform the set operation
     for key, (A_nodes, B_nodes) in nodes_by_key.items():
-        output = list(_operation(key, A_nodes, B_nodes, operation_type, node_type))
+        output = list(
+            _operation(key, A_nodes, B_nodes, operation_type, node_type, depth + 1)
+        )
+        # print(f"{' '*depth}_operation {operation_type.name} {A_nodes} {B_nodes} out = [{output}]")
         new_children.extend(output)

-    # print(f"operation {operation_type}: {A}, {B} {new_children = }")
-    # print(f"{A.children = }")
-    # print(f"{B.children = }")
-    # print(f"{new_children = }")
+    # print(f"{' '*depth}operation {operation_type.name} [{A}] [{B}] new_children = [{new_children}]")

     # If there are now no children as a result of the operation, return nothing.
     if (A.children or B.children) and not new_children:
         if A.key == "root":
-            return A.replace(children=())
+            return node_type.make_root(children=())
         else:
             return None
@@ -140,20 +183,34 @@ def operation(A: Qube, B: Qube, operation_type: SetOperation, node_type) -> Qube
         new_children = list(compress_children(new_children))

     # The values and key are the same so we just replace the children
-    return A.replace(children=tuple(sorted(new_children)))
+    return node_type.make_node(
+        key=node_key,
+        values=node_values,
+        children=new_children,
+        metadata=stayput_metadata,
+        is_root=is_root,
+    )
+
+
+def get_indices(metadata: dict[str, np.ndarray], indices: list[int] | slice):
+    return {k: v[..., indices] for k, v in metadata.items()}

 # The root node is special so we need a helper method that we can recurse on
 def _operation(
-    key: str, A: list[Qube], B: list[Qube], operation_type: SetOperation, node_type
+    key: str,
+    A: list[Qube],
+    B: list[Qube],
+    operation_type: SetOperation,
+    node_type,
+    depth: int,
 ) -> Iterable[Qube]:
     keep_just_A, keep_intersection, keep_just_B = operation_type.value

-    # Iterate over all pairs (node_A, node_B)
     values = {}
     for node in A + B:
         values[node] = ValuesMetadata(node.values, node.metadata)

+    # Iterate over all pairs (node_A, node_B)
     for node_a in A:
         for node_b in B:
             # Compute A - B, A & B, B - A
@@ -171,17 +228,21 @@ def _operation(
             if intersection.values:
                 new_node_a = node_a.replace(
                     values=intersection.values,
-                    metadata=intersection.metadata,
+                    metadata=get_indices(node_a.metadata, intersection.indices),
                 )
                 new_node_b = node_b.replace(
                     values=intersection.values,
-                    metadata=intersection.metadata,
+                    metadata=get_indices(node_b.metadata, intersection.indices),
                 )
-                # print(f"{node_a = }")
-                # print(f"{node_b = }")
-                # print(f"{intersection.values =}")
+                # print(f"{' '*depth}{node_a = }")
+                # print(f"{' '*depth}{node_b = }")
+                # print(f"{' '*depth}{intersection.values =}")
                 result = operation(
-                    new_node_a, new_node_b, operation_type, node_type
+                    new_node_a,
+                    new_node_b,
+                    operation_type,
+                    node_type,
+                    depth=depth + 1,
                 )
                 if result is not None:
                     yield result
@@ -190,20 +251,20 @@ def _operation(
     if keep_just_A:
         for node in A:
             if values[node].values:
-                yield node_type.make(
+                yield node_type.make_node(
                     key,
                     children=node.children,
                     values=values[node].values,
-                    metadata=values[node].metadata,
+                    metadata=get_indices(node.metadata, values[node].indices),
                 )

     if keep_just_B:
         for node in B:
             if values[node].values:
-                yield node_type.make(
+                yield node_type.make_node(
                     key,
                     children=node.children,
                     values=values[node].values,
-                    metadata=values[node].metadata,
+                    metadata=get_indices(node.metadata, values[node].indices),
                 )
@@ -230,7 +291,7 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
             key = child_list[0].key

             # Compress the children into a single node
-            assert all(isinstance(child.data.values, QEnum) for child in child_list), (
+            assert all(isinstance(child.values, QEnum) for child in child_list), (
                 "All children must have QEnum values"
             )
@@ -241,19 +302,19 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
             metadata: frozendict[str, np.ndarray] = frozendict(
                 {
-                    k: np.concatenate(metadata_group, axis=0)
+                    k: np.concatenate(metadata_group, axis=-1)
                     for k, metadata_group in metadata_groups.items()
                 }
             )

-            node_data = NodeData(
-                key=key,
-                metadata=metadata,
-                values=QEnum(set(v for child in child_list for v in child.data.values)),
-            )
             children = [cc for c in child_list for cc in c.children]
             compressed_children = compress_children(children)
-            new_child = node_type(data=node_data, children=compressed_children)
+            new_child = node_type.make_node(
+                key=key,
+                metadata=metadata,
+                values=QEnum(set(v for child in child_list for v in child.values)),
+                children=compressed_children,
+            )
         else:
             # If the group is size one just keep it
             new_child = child_list.pop()
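
For reference, a small NumPy-only sketch (assumed shapes, not from the commit) of the two metadata operations used above: get_indices selects along the trailing values axis, and the push-down loop broadcasts an extra trailing axis of length N:

import numpy as np

metadata = {"number": np.arange(12).reshape(4, 3)}  # last axis: one entry per value

# equivalent to get_indices(metadata, [0, 2])
subset = {k: v[..., [0, 2]] for k, v in metadata.items()}
assert subset["number"].shape == (4, 2)

# equivalent to the push-down broadcast for a child with N = 5 values
v = metadata["number"]
pushed = np.broadcast_to(v[..., np.newaxis], v.shape + (5,))
assert pushed.shape == (4, 3, 5)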

View File

@@ -2,9 +2,8 @@ from __future__ import annotations

 import random
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Iterable
-
-import numpy as np
+from typing import TYPE_CHECKING, Callable, Iterable

 if TYPE_CHECKING:
     from .Qube import Qube
@@ -71,27 +70,38 @@ def node_tree_to_string(node: Qube, prefix: str = "", depth=None) -> Iterable[str]:

 def summarize_node_html(
-    node: Qube, collapse=False, max_summary_length=50, **kwargs
+    node: Qube,
+    collapse=False,
+    max_summary_length=50,
+    info: Callable[[Qube], str] | None = None,
+    **kwargs,
 ) -> tuple[str, Qube]:
     """
     Extracts a summarized representation of the node while collapsing single-child paths.
     Returns the summary string and the last node in the chain that has multiple children.
     """
+    if info is None:
+
+        def info_func(node: Qube, /):
+            return (
+                # f"dtype: {node.dtype}\n"
+                f"metadata: {dict(node.metadata)}\n"
+            )
+    else:
+        info_func = info
+
     summaries = []

     while True:
         path = node.summary(**kwargs)
         summary = path
+        if "is_leaf" in node.metadata and node.metadata["is_leaf"]:
+            summary += " 🌿"

         if len(summary) > max_summary_length:
             summary = summary[:max_summary_length] + "..."
-        info = (
-            f"dtype: {node.dtype.__name__}\n"
-            f"metadata: {dict((k, np.shape(v)) for k, v in node.metadata.items())}\n"
-        )
-        summary = f'<span class="qubed-node" data-path="{path}" title="{info}">{summary}</span>'
+
+        info_string = info_func(node)
+
+        summary = f'<span class="qubed-node" data-path="{path}" title="{info_string}">{summary}</span>'
         summaries.append(summary)
         if not collapse:
             break
@@ -105,9 +115,14 @@ def summarize_node_html(

 def _node_tree_to_html(
-    node: Qube, prefix: str = "", depth=1, connector="", **kwargs
+    node: Qube,
+    prefix: str = "",
+    depth=1,
+    connector="",
+    info: Callable[[Qube], str] | None = None,
+    **kwargs,
 ) -> Iterable[str]:
-    summary, node = summarize_node_html(node, **kwargs)
+    summary, node = summarize_node_html(node, info=info, **kwargs)

     if len(node.children) == 0:
         yield f'<span class="qubed-level">{connector}{summary}</span>'
@@ -124,13 +139,20 @@ def _node_tree_to_html(
             prefix + extension,
             depth=depth - 1,
             connector=prefix + connector,
+            info=info,
             **kwargs,
         )
     yield "</details>"


 def node_tree_to_html(
-    node: Qube, depth=1, include_css=True, include_js=True, css_id=None, **kwargs
+    node: Qube,
+    depth=1,
+    include_css=True,
+    include_js=True,
+    css_id=None,
+    info: Callable[[Qube], str] | None = None,
+    **kwargs,
 ) -> str:
     if css_id is None:
         css_id = f"qubed-tree-{random.randint(0, 1000000)}"
@@ -215,5 +237,5 @@ def node_tree_to_html(
         nodes.forEach(n => n.addEventListener("click", nodeOnClick));
     </script>
     """.replace("CSS_ID", css_id)
-    nodes = "".join(_node_tree_to_html(node=node, depth=depth, **kwargs))
+    nodes = "".join(_node_tree_to_html(node=node, depth=depth, info=info, **kwargs))
     return f"{js if include_js else ''}{css if include_css else ''}<pre class='qubed-tree' id='{css_id}'>{nodes}</pre>"

View File

@@ -2,7 +2,7 @@ from __future__ import annotations

 import dataclasses
 from abc import ABC, abstractmethod
-from dataclasses import dataclass, replace
+from dataclasses import dataclass
 from datetime import date, datetime, timedelta
 from typing import (
     TYPE_CHECKING,
@@ -21,6 +21,11 @@ if TYPE_CHECKING:

 @dataclass(frozen=True)
 class ValueGroup(ABC):
+    @abstractmethod
+    def dtype(self) -> str:
+        "Provide a string rep of the datatype of these values"
+        pass
+
     @abstractmethod
     def summary(self) -> str:
         "Provide a string summary of the value group."
@@ -69,9 +74,13 @@ class QEnum(ValueGroup):
     """

     values: EnumValuesType
+    _dtype: str = "str"

     def __init__(self, obj):
         object.__setattr__(self, "values", tuple(sorted(obj)))
+        object.__setattr__(
+            self, "dtype", type(self.values[0]) if len(self.values) > 0 else "str"
+        )

     def __post_init__(self):
         assert isinstance(self.values, tuple)
@@ -88,6 +97,9 @@ class QEnum(ValueGroup):
     def __contains__(self, value: Any) -> bool:
         return value in self.values

+    def dtype(self):
+        return self._dtype
+
     @classmethod
     def from_strings(cls, values: Iterable[str]) -> Sequence[ValueGroup]:
         return [cls(tuple(values))]
@@ -114,7 +126,7 @@ class WildcardGroup(ValueGroup):
         return "*"

     def __len__(self):
-        return None
+        return 1

     def __iter__(self):
         return ["*"]
@@ -122,6 +134,9 @@ class WildcardGroup(ValueGroup):
     def __bool__(self):
         return True

+    def dtype(self):
+        return "*"
+
     @classmethod
     def from_strings(cls, values: Iterable[str]) -> Sequence[ValueGroup]:
         return [WildcardGroup()]
@@ -398,7 +413,7 @@ def convert_datatypes(q: "Qube", conversions: dict[str, ValueGroup]) -> "Qube":
             )
             for values_group in data_type.from_strings(q.values):
                 # print(values_group)
-                yield replace(q, data=replace(q.data, values=values_group))
+                yield q.replace(values=values_group)
         else:
             yield q
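
A short sketch of the value-group behaviour the changes above rely on (the module path qubed.value_types is assumed from the relative imports in this diff):

from qubed.value_types import QEnum, WildcardGroup

e = QEnum(["2", "1", "3"])
print(e.values)          # stored sorted as a tuple: ('1', '2', '3')
print("2" in e)          # True, via __contains__

w = WildcardGroup()
print(len(w), bool(w))   # 1 True — wildcard groups now report a length and stay truthy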

src/qube.proto (new file, +32 lines)
View File

@@ -0,0 +1,32 @@
syntax = "proto3";

message NdArray {
    repeated int64 shape = 1;
    string dtype = 2;
    bytes raw = 3;
}

message StringGroup {repeated string items = 1; }

// Stores values i.e class=1/2/3 the 1/2/3 part
message ValueGroup {
    oneof payload {
        StringGroup s = 1;
        NdArray tensor = 2;
    }
}

message MetadataGroup {
    oneof payload {
        NdArray tensor = 1;
    }
}

message Qube {
    string key = 1;
    ValueGroup values = 2;
    map<string, MetadataGroup> metadata = 3;
    string dtype = 4;
    repeated Qube children = 5;
    bool is_root = 6;
}
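
One way to regenerate qube_pb2.py from this schema, sketched with grpcio-tools; the output directory is an assumption about the repository layout (the generated module sits next to adapters.py), and the checked-in file records protoc 5.29.0:

from grpc_tools import protoc

protoc.main([
    "protoc",
    "-Isrc",                                   # directory containing qube.proto
    "--python_out=src/python/qubed/protobuf",  # assumed location of qube_pb2.py
    "src/qube.proto",
])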

View File

@@ -176,7 +176,7 @@ def follow_query(request: dict[str, str | list[str]], qube: Qube):
     by_path = defaultdict(lambda: {"paths": set(), "values": set()})

     for request, node in s.leaf_nodes():
-        if not node.data.metadata["is_leaf"]:
+        if not node.metadata["is_leaf"]:
             by_path[node.key]["values"].update(node.values.values)
             by_path[node.key]["paths"].add(frozendict(request))

View File

@@ -1,36 +1,40 @@
 from qubed import Qube

-d = {
-    "class=od": {
-        "expver=0001": {"param=1": {}, "param=2": {}},
-        "expver=0002": {"param=1": {}, "param=2": {}},
-    },
-    "class=rd": {
-        "expver=0001": {"param=1": {}, "param=2": {}, "param=3": {}},
-        "expver=0002": {"param=1": {}, "param=2": {}},
-    },
-}
-q = Qube.from_dict(d)
-
-
-def test_eq():
-    r = Qube.from_dict(d)
-    assert q == r
+q = Qube.from_tree("""
+root
+    class=od
+        expver=0001
+            param=1
+            param=2
+        expver=0002
+            param=1
+            param=2
+    class=rd
+        expver=0001
+            param=1
+            param=2
+            param=3
+        expver=0002
+            param=1
+            param=2
+""")


 def test_getitem():
-    assert q["class", "od"] == Qube.from_dict(
-        {
-            "expver=0001": {"param=1": {}, "param=2": {}},
-            "expver=0002": {"param=1": {}, "param=2": {}},
-        }
-    )
-    assert q["class", "od"]["expver", "0001"] == Qube.from_dict(
-        {
-            "param=1": {},
-            "param=2": {},
-        }
-    )
+    assert q["class", "od"] == Qube.from_tree("""
+    root
+        expver=0001
+            param=1
+            param=2
+        expver=0002
+            param=1
+            param=2
+    """)
+
+    assert q["class", "od"]["expver", "0001"] == Qube.from_tree("""
+    root
+        param=1
+        param=2""")


 def test_n_leaves():

View File

@@ -2,7 +2,7 @@ from qubed import Qube


 def test_json_round_trip():
-    u = Qube.from_dict(
+    from_dict = Qube.from_dict(
         {
             "class=d1": {
                 "dataset=climate-dt/weather-dt": {
@@ -14,5 +14,54 @@ def test_json_round_trip():
             }
         }
     )
-    json = u.to_json()
-    assert Qube.from_json(json) == u
+
+    from_tree = Qube.from_tree("""
+    root, class=d1
+        dataset=another-value, generation=1/2/3
+        dataset=climate-dt/weather-dt, generation=1/2/3/4
+    """)
+
+    from_json = Qube.from_json(
+        {
+            "key": "root",
+            "values": ["root"],
+            "metadata": {},
+            "children": [
+                {
+                    "key": "class",
+                    "values": ["d1"],
+                    "metadata": {},
+                    "children": [
+                        {
+                            "key": "dataset",
+                            "values": ["another-value"],
+                            "metadata": {},
+                            "children": [
+                                {
+                                    "key": "generation",
+                                    "values": ["1", "2", "3"],
+                                    "metadata": {},
+                                    "children": [],
+                                }
+                            ],
+                        },
+                        {
+                            "key": "dataset",
+                            "values": ["climate-dt", "weather-dt"],
+                            "metadata": {},
+                            "children": [
+                                {
+                                    "key": "generation",
+                                    "values": ["1", "2", "3", "4"],
+                                    "metadata": {},
+                                    "children": [],
+                                }
+                            ],
+                        },
+                    ],
+                }
+            ],
+        }
+    )
+
+    assert from_tree == from_json
+    assert from_tree == from_dict

View File

@@ -20,14 +20,6 @@ root
     expver=0002, param=1/2
 """.strip()

-as_html = """
-<details open><summary class="qubed-level"><span class="qubed-node" data-path="root" title="dtype: str\nmetadata: {}\n">root</span></summary><span class="qubed-level"> <span class="qubed-node" data-path="class=od" title="dtype: str\nmetadata: {}\n">class=od</span>, <span class="qubed-node" data-path="expver=0001/0002" title="dtype: str\nmetadata: {}\n">expver=0001/0002</span>, <span class="qubed-node" data-path="param=1/2" title="dtype: str\nmetadata: {}\n">param=1/2</span></span><details open><summary class="qubed-level"> <span class="qubed-node" data-path="class=rd" title="dtype: str\nmetadata: {}\n">class=rd</span></summary><span class="qubed-level"> <span class="qubed-node" data-path="expver=0001" title="dtype: str\nmetadata: {}\n">expver=0001</span>, <span class="qubed-node" data-path="param=1/2/3" title="dtype: str\nmetadata: {}\n">param=1/2/3</span></span><span class="qubed-level"> <span class="qubed-node" data-path="expver=0002" title="dtype: str\nmetadata: {}\n">expver=0002</span>, <span class="qubed-node" data-path="param=1/2" title="dtype: str\nmetadata: {}\n">param=1/2</span></span></details></details>
-""".strip()
-

 def test_string():
     assert str(q).strip() == as_string
-
-
-def test_html():
-    assert as_html in q._repr_html_()

View File

@@ -1,45 +1,44 @@
 from frozendict import frozendict
-from qubed import Qube


 def make_set(entries):
     return set((frozendict(a), frozendict(b)) for a, b in entries)


-def test_simple_union():
-    q = Qube.from_nodes(
-        {
-            "class": dict(values=["od", "rd"]),
-            "expver": dict(values=[1, 2]),
-            "stream": dict(
-                values=["a", "b", "c"], metadata=dict(number=list(range(12)))
-            ),
-        }
-    )
-
-    r = Qube.from_nodes(
-        {
-            "class": dict(values=["xd"]),
-            "expver": dict(values=[1, 2]),
-            "stream": dict(
-                values=["a", "b", "c"], metadata=dict(number=list(range(12, 18)))
-            ),
-        }
-    )
-
-    expected_union = Qube.from_nodes(
-        {
-            "class": dict(values=["od", "rd", "xd"]),
-            "expver": dict(values=[1, 2]),
-            "stream": dict(
-                values=["a", "b", "c"], metadata=dict(number=list(range(18)))
-            ),
-        }
-    )
-
-    union = q | r
-
-    assert union == expected_union
-    assert make_set(expected_union.leaves_with_metadata()) == make_set(
-        union.leaves_with_metadata()
-    )
+# def test_simple_union():
+#     q = Qube.from_nodes(
+#         {
+#             "class": dict(values=["od", "rd"]),
+#             "expver": dict(values=[1, 2]),
+#             "stream": dict(
+#                 values=["a", "b", "c"], metadata=dict(number=list(range(12)))
+#             ),
+#         }
+#     )
+
+#     r = Qube.from_nodes(
+#         {
+#             "class": dict(values=["xd"]),
+#             "expver": dict(values=[1, 2]),
+#             "stream": dict(
+#                 values=["a", "b", "c"], metadata=dict(number=list(range(12, 18)))
+#             ),
+#         }
+#     )
+
+#     expected_union = Qube.from_nodes(
+#         {
+#             "class": dict(values=["od", "rd", "xd"]),
+#             "expver": dict(values=[1, 2]),
+#             "stream": dict(
+#                 values=["a", "b", "c"], metadata=dict(number=list(range(18)))
+#             ),
+#         }
+#     )
+
+#     union = q | r
+
+#     assert union == expected_union
+#     assert make_set(expected_union.leaves_with_metadata()) == make_set(
+#         union.leaves_with_metadata()
+#     )

tests/test_protobuf.py (new file, +12 lines)
View File

@@ -0,0 +1,12 @@
from qubed import Qube


def test_protobuf_simple():
    q = Qube.from_tree("""
    root, class=d1
        dataset=another-value, generation=1/2/3
        dataset=climate-dt/weather-dt, generation=1/2/3/4
    """)
    wire = q.to_protobuf()
    round_trip = Qube.from_protobuf(wire)
    assert round_trip == q