Massive rewrite
This commit is contained in:
parent
ed4a9055fa
commit
35bb8f0edd
@ -8,17 +8,17 @@ import functools
|
|||||||
import json
|
import json
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Iterable, Iterator, Literal, Sequence
|
from typing import Any, Iterable, Iterator, Literal, Self, Sequence
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
|
|
||||||
from . import set_operations
|
from . import set_operations
|
||||||
from .metadata import from_nodes
|
from .metadata import from_nodes
|
||||||
from .node_types import NodeData, RootNodeData
|
from .protobuf.adapters import proto_to_qube, qube_to_proto
|
||||||
from .tree_formatters import (
|
from .tree_formatters import (
|
||||||
HTML,
|
HTML,
|
||||||
node_tree_to_html,
|
node_tree_to_html,
|
||||||
@ -32,58 +32,90 @@ from .value_types import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=False, eq=True, order=True, unsafe_hash=True)
|
@dataclass
|
||||||
class Qube:
|
class AxisInfo:
|
||||||
data: NodeData
|
key: str
|
||||||
children: tuple[Qube, ...]
|
type: Any
|
||||||
|
depths: set[int]
|
||||||
|
values: set
|
||||||
|
|
||||||
@property
|
def combine(self, other: Self):
|
||||||
def key(self) -> str:
|
self.key = other.key
|
||||||
return self.data.key
|
self.type = other.type
|
||||||
|
self.depths.update(other.depths)
|
||||||
|
self.values.update(other.values)
|
||||||
|
# print(f"combining {self} and {other} getting {result}")
|
||||||
|
|
||||||
@property
|
def to_json(self):
|
||||||
def values(self) -> ValueGroup:
|
return {
|
||||||
return self.data.values
|
"key": self.key,
|
||||||
|
"type": self.type.__name__,
|
||||||
@property
|
"values": list(self.values),
|
||||||
def metadata(self):
|
"depths": list(self.depths),
|
||||||
return self.data.metadata
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dtype(self):
|
|
||||||
return self.data.dtype
|
|
||||||
|
|
||||||
def replace(self, **kwargs) -> Qube:
|
|
||||||
data_keys = {
|
|
||||||
k: v
|
|
||||||
for k, v in kwargs.items()
|
|
||||||
if k in ["key", "values", "metadata", "dtype"]
|
|
||||||
}
|
}
|
||||||
node_keys = {k: v for k, v in kwargs.items() if k == "children"}
|
|
||||||
if not data_keys and not node_keys:
|
|
||||||
return self
|
|
||||||
if not data_keys:
|
|
||||||
return dataclasses.replace(self, **node_keys)
|
|
||||||
|
|
||||||
return dataclasses.replace(
|
|
||||||
self, data=dataclasses.replace(self.data, **data_keys), **node_keys
|
@dataclass(frozen=True, eq=True, order=True, unsafe_hash=True)
|
||||||
)
|
class QubeNamedRoot:
|
||||||
|
"Helper class to print a custom root name"
|
||||||
|
|
||||||
|
key: str
|
||||||
|
dtype: str = "str"
|
||||||
|
children: tuple[Qube, ...] = ()
|
||||||
|
|
||||||
def summary(self) -> str:
|
def summary(self) -> str:
|
||||||
return self.data.summary()
|
return self.key
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, eq=True, order=True, unsafe_hash=True)
|
||||||
|
class Qube:
|
||||||
|
key: str
|
||||||
|
values: ValueGroup
|
||||||
|
metadata: frozendict[str, np.ndarray] = field(
|
||||||
|
default_factory=lambda: frozendict({}), compare=False
|
||||||
|
)
|
||||||
|
children: tuple[Qube, ...] = ()
|
||||||
|
is_root: bool = False
|
||||||
|
|
||||||
|
def replace(self, **kwargs) -> Qube:
|
||||||
|
return dataclasses.replace(self, **kwargs)
|
||||||
|
|
||||||
|
def summary(self) -> str:
|
||||||
|
if self.is_root:
|
||||||
|
return self.key
|
||||||
|
return f"{self.key}={self.values.summary()}" if self.key != "root" else "root"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def make(cls, key: str, values: ValueGroup, children, **kwargs) -> Qube:
|
def make_node(
|
||||||
|
cls,
|
||||||
|
key: str,
|
||||||
|
values: Iterable | QEnum | WildcardGroup,
|
||||||
|
children: Iterable[Qube],
|
||||||
|
metadata: dict[str, np.ndarray] = {},
|
||||||
|
is_root: bool = False,
|
||||||
|
) -> Qube:
|
||||||
|
if isinstance(values, ValueGroup):
|
||||||
|
values = values
|
||||||
|
else:
|
||||||
|
values = QEnum(values)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
data=NodeData(
|
key,
|
||||||
key, values, metadata=frozendict(kwargs.get("metadata", frozendict()))
|
values=values,
|
||||||
),
|
|
||||||
children=tuple(sorted(children, key=lambda n: ((n.key, n.values.min())))),
|
children=tuple(sorted(children, key=lambda n: ((n.key, n.values.min())))),
|
||||||
|
metadata=frozendict(metadata),
|
||||||
|
is_root=is_root,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def root_node(cls, children: Iterable[Qube]) -> Qube:
|
def make_root(cls, children: Iterable[Qube], metadata={}) -> Qube:
|
||||||
return cls.make("root", QEnum(("root",)), children)
|
return cls.make_node(
|
||||||
|
"root",
|
||||||
|
values=QEnum(("root",)),
|
||||||
|
children=children,
|
||||||
|
metadata=metadata,
|
||||||
|
is_root=True,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, path: str | Path) -> Qube:
|
def load(cls, path: str | Path) -> Qube:
|
||||||
@ -104,18 +136,19 @@ class Qube:
|
|||||||
else:
|
else:
|
||||||
values_group = QEnum([values])
|
values_group = QEnum([values])
|
||||||
|
|
||||||
children = [cls.make(key, values_group, children)]
|
children = [cls.make_node(key, values_group, children)]
|
||||||
|
|
||||||
return cls.root_node(children)
|
return cls.make_root(children)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_json(cls, json: dict) -> Qube:
|
def from_json(cls, json: dict) -> Qube:
|
||||||
def from_json(json: dict) -> Qube:
|
def from_json(json: dict, depth=0) -> Qube:
|
||||||
return Qube.make(
|
return Qube.make_node(
|
||||||
key=json["key"],
|
key=json["key"],
|
||||||
values=values_from_json(json["values"]),
|
values=values_from_json(json["values"]),
|
||||||
metadata=frozendict(json["metadata"]) if "metadata" in json else {},
|
metadata=frozendict(json["metadata"]) if "metadata" in json else {},
|
||||||
children=(from_json(c) for c in json["children"]),
|
children=(from_json(c, depth + 1) for c in json["children"]),
|
||||||
|
is_root=(depth == 0),
|
||||||
)
|
)
|
||||||
|
|
||||||
return from_json(json)
|
return from_json(json)
|
||||||
@ -146,13 +179,13 @@ class Qube:
|
|||||||
else:
|
else:
|
||||||
values = QEnum(values)
|
values = QEnum(values)
|
||||||
|
|
||||||
yield Qube.make(
|
yield Qube.make_node(
|
||||||
key=key,
|
key=key,
|
||||||
values=values,
|
values=values,
|
||||||
children=from_dict(children),
|
children=from_dict(children),
|
||||||
)
|
)
|
||||||
|
|
||||||
return Qube.root_node(list(from_dict(d)))
|
return Qube.make_root(list(from_dict(d)))
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
def to_dict(q: Qube) -> tuple[str, dict]:
|
def to_dict(q: Qube) -> tuple[str, dict]:
|
||||||
@ -161,6 +194,13 @@ class Qube:
|
|||||||
|
|
||||||
return to_dict(self)[1]
|
return to_dict(self)[1]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_protobuf(cls, msg: bytes) -> Qube:
|
||||||
|
return proto_to_qube(cls, msg)
|
||||||
|
|
||||||
|
def to_protobuf(self) -> bytes:
|
||||||
|
return qube_to_proto(self)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_tree(cls, tree_str):
|
def from_tree(cls, tree_str):
|
||||||
lines = tree_str.splitlines()
|
lines = tree_str.splitlines()
|
||||||
@ -214,17 +254,12 @@ class Qube:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def empty(cls) -> Qube:
|
def empty(cls) -> Qube:
|
||||||
return Qube.root_node([])
|
return Qube.make_root([])
|
||||||
|
|
||||||
def __str_helper__(self, depth=None, name=None) -> str:
|
def __str_helper__(self, depth=None, name=None) -> str:
|
||||||
node = (
|
node = self
|
||||||
dataclasses.replace(
|
if name is not None:
|
||||||
self,
|
node = node.replace(key=name)
|
||||||
data=RootNodeData(key=name, values=self.values, metadata=self.metadata),
|
|
||||||
)
|
|
||||||
if name is not None
|
|
||||||
else self
|
|
||||||
)
|
|
||||||
out = "".join(node_tree_to_string(node=node, depth=depth))
|
out = "".join(node_tree_to_string(node=node, depth=depth))
|
||||||
if out[-1] == "\n":
|
if out[-1] == "\n":
|
||||||
out = out[:-1]
|
out = out[:-1]
|
||||||
@ -239,16 +274,19 @@ class Qube:
|
|||||||
def print(self, depth=None, name: str | None = None):
|
def print(self, depth=None, name: str | None = None):
|
||||||
print(self.__str_helper__(depth=depth, name=name))
|
print(self.__str_helper__(depth=depth, name=name))
|
||||||
|
|
||||||
def html(self, depth=2, collapse=True, name: str | None = None) -> HTML:
|
def html(
|
||||||
node = (
|
self,
|
||||||
dataclasses.replace(
|
depth=2,
|
||||||
self,
|
collapse=True,
|
||||||
data=RootNodeData(key=name, values=self.values, metadata=self.metadata),
|
name: str | None = None,
|
||||||
)
|
info: Callable[[Qube], str] | None = None,
|
||||||
if name is not None
|
) -> HTML:
|
||||||
else self
|
node = self
|
||||||
|
if name is not None:
|
||||||
|
node = node.replace(key=name)
|
||||||
|
return HTML(
|
||||||
|
node_tree_to_html(node=node, depth=depth, collapse=collapse, info=info)
|
||||||
)
|
)
|
||||||
return HTML(node_tree_to_html(node=node, depth=depth, collapse=collapse))
|
|
||||||
|
|
||||||
def _repr_html_(self) -> str:
|
def _repr_html_(self) -> str:
|
||||||
return node_tree_to_html(self, depth=2, collapse=True)
|
return node_tree_to_html(self, depth=2, collapse=True)
|
||||||
@ -257,7 +295,7 @@ class Qube:
|
|||||||
def __rtruediv__(self, other: str) -> Qube:
|
def __rtruediv__(self, other: str) -> Qube:
|
||||||
key, values = other.split("=")
|
key, values = other.split("=")
|
||||||
values_enum = QEnum((values.split("/")))
|
values_enum = QEnum((values.split("/")))
|
||||||
return Qube.root_node([Qube.make(key, values_enum, self.children)])
|
return Qube.make_root([Qube.make_node(key, values_enum, self.children)])
|
||||||
|
|
||||||
def __or__(self, other: Qube) -> Qube:
|
def __or__(self, other: Qube) -> Qube:
|
||||||
return set_operations.operation(
|
return set_operations.operation(
|
||||||
@ -358,16 +396,16 @@ class Qube:
|
|||||||
raise KeyError(
|
raise KeyError(
|
||||||
f"Key '{key}' not found in children of '{current.key}', available keys are {[c.key for c in current.children]}"
|
f"Key '{key}' not found in children of '{current.key}', available keys are {[c.key for c in current.children]}"
|
||||||
)
|
)
|
||||||
return Qube.root_node(current.children)
|
return Qube.make_root(current.children)
|
||||||
|
|
||||||
elif isinstance(args, tuple) and len(args) == 2:
|
elif isinstance(args, tuple) and len(args) == 2:
|
||||||
key, value = args
|
key, value = args
|
||||||
for c in self.children:
|
for c in self.children:
|
||||||
if c.key == key and value in c.values:
|
if c.key == key and value in c.values:
|
||||||
return Qube.root_node(c.children)
|
return Qube.make_root(c.children)
|
||||||
raise KeyError(f"Key {key} not found in children of {self.key}")
|
raise KeyError(f"Key '{key}' not found in children of {self.key}")
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown key type")
|
raise ValueError(f"Unknown key type {args}")
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def n_leaves(self) -> int:
|
def n_leaves(self) -> int:
|
||||||
@ -410,7 +448,7 @@ class Qube:
|
|||||||
for c in node.children:
|
for c in node.children:
|
||||||
if c.key in _keys:
|
if c.key in _keys:
|
||||||
grandchildren = tuple(sorted(remove_key(cc) for cc in c.children))
|
grandchildren = tuple(sorted(remove_key(cc) for cc in c.children))
|
||||||
grandchildren = remove_key(Qube.root_node(grandchildren)).children
|
grandchildren = remove_key(Qube.make_root(grandchildren)).children
|
||||||
children.extend(grandchildren)
|
children.extend(grandchildren)
|
||||||
else:
|
else:
|
||||||
children.append(remove_key(c))
|
children.append(remove_key(c))
|
||||||
@ -424,7 +462,7 @@ class Qube:
|
|||||||
if node.key in converters:
|
if node.key in converters:
|
||||||
converter = converters[node.key]
|
converter = converters[node.key]
|
||||||
values = [converter(v) for v in node.values]
|
values = [converter(v) for v in node.values]
|
||||||
new_node = node.replace(values=QEnum(values), dtype=type(values[0]))
|
new_node = node.replace(values=QEnum(values))
|
||||||
return new_node
|
return new_node
|
||||||
return node
|
return node
|
||||||
|
|
||||||
@ -516,7 +554,8 @@ class Qube:
|
|||||||
|
|
||||||
return node.replace(
|
return node.replace(
|
||||||
children=new_children,
|
children=new_children,
|
||||||
metadata=dict(self.metadata) | {"is_leaf": not bool(new_children)},
|
metadata=dict(self.metadata)
|
||||||
|
| ({"is_leaf": not bool(new_children)} if mode == "next_level" else {}),
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.replace(
|
return self.replace(
|
||||||
@ -544,6 +583,26 @@ class Qube:
|
|||||||
axes[self.key].update(self.values)
|
axes[self.key].update(self.values)
|
||||||
return dict(axes)
|
return dict(axes)
|
||||||
|
|
||||||
|
def axes_info(self, depth=0) -> dict[str, AxisInfo]:
|
||||||
|
axes = defaultdict(
|
||||||
|
lambda: AxisInfo(key="", type=str, depths=set(), values=set())
|
||||||
|
)
|
||||||
|
for c in self.children:
|
||||||
|
for k, info in c.axes_info(depth=depth + 1).items():
|
||||||
|
axes[k].combine(info)
|
||||||
|
|
||||||
|
if self.key != "root":
|
||||||
|
axes[self.key].combine(
|
||||||
|
AxisInfo(
|
||||||
|
key=self.key,
|
||||||
|
type=type(next(iter(self.values))),
|
||||||
|
depths={depth},
|
||||||
|
values=set(self.values),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return dict(axes)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def structural_hash(self) -> int:
|
def structural_hash(self) -> int:
|
||||||
"""
|
"""
|
||||||
@ -570,7 +629,7 @@ class Qube:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def union(a: Qube, b: Qube) -> Qube:
|
def union(a: Qube, b: Qube) -> Qube:
|
||||||
b = type(self).root_node(children=(b,))
|
b = type(self).make_root(children=(b,))
|
||||||
out = set_operations.operation(
|
out = set_operations.operation(
|
||||||
a, b, set_operations.SetOperation.UNION, type(self)
|
a, b, set_operations.SetOperation.UNION, type(self)
|
||||||
)
|
)
|
||||||
@ -583,3 +642,20 @@ class Qube:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return self.replace(children=tuple(sorted(new_children)))
|
return self.replace(children=tuple(sorted(new_children)))
|
||||||
|
|
||||||
|
def add_metadata(self, **kwargs: dict[str, Any]):
|
||||||
|
metadata = {
|
||||||
|
k: np.array(
|
||||||
|
[
|
||||||
|
v,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
for k, v in kwargs.items()
|
||||||
|
}
|
||||||
|
return self.replace(metadata=metadata)
|
||||||
|
|
||||||
|
def strip_metadata(self) -> Qube:
|
||||||
|
def strip(node):
|
||||||
|
return node.replace(metadata=frozendict({}))
|
||||||
|
|
||||||
|
return self.transform(strip)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from . import protobuf
|
||||||
from .Qube import Qube
|
from .Qube import Qube
|
||||||
|
|
||||||
__all__ = ["Qube"]
|
__all__ = ["Qube", "protobuf"]
|
||||||
|
@ -18,7 +18,7 @@ def make_node(
|
|||||||
children: tuple[Qube, ...],
|
children: tuple[Qube, ...],
|
||||||
metadata: dict[str, np.ndarray] | None = None,
|
metadata: dict[str, np.ndarray] | None = None,
|
||||||
):
|
):
|
||||||
return cls.make(
|
return cls.make_node(
|
||||||
key=key,
|
key=key,
|
||||||
values=QEnum(values),
|
values=QEnum(values),
|
||||||
metadata={k: np.array(v).reshape(shape) for k, v in metadata.items()}
|
metadata={k: np.array(v).reshape(shape) for k, v in metadata.items()}
|
||||||
@ -39,5 +39,5 @@ def from_nodes(cls, nodes, add_root=True):
|
|||||||
root = make_node(cls, shape=shape, children=(root,), key=key, **info)
|
root = make_node(cls, shape=shape, children=(root,), key=key, **info)
|
||||||
|
|
||||||
if add_root:
|
if add_root:
|
||||||
return cls.root_node(children=(root,))
|
return cls.make_root(children=(root,))
|
||||||
return root
|
return root
|
||||||
|
@ -1,27 +0,0 @@
|
|||||||
from dataclasses import dataclass, field
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from frozendict import frozendict
|
|
||||||
|
|
||||||
from .value_types import ValueGroup
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=False, eq=True, order=True, unsafe_hash=True)
|
|
||||||
class NodeData:
|
|
||||||
key: str
|
|
||||||
values: ValueGroup
|
|
||||||
metadata: frozendict[str, np.ndarray] = field(
|
|
||||||
default_factory=lambda: frozendict({}), compare=False
|
|
||||||
)
|
|
||||||
dtype: type = str
|
|
||||||
|
|
||||||
def summary(self) -> str:
|
|
||||||
return f"{self.key}={self.values.summary()}" if self.key != "root" else "root"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=False, eq=True, order=True)
|
|
||||||
class RootNodeData(NodeData):
|
|
||||||
"Helper class to print a custom root name"
|
|
||||||
|
|
||||||
def summary(self) -> str:
|
|
||||||
return self.key
|
|
0
src/python/qubed/protobuf/__init__.py
Normal file
0
src/python/qubed/protobuf/__init__.py
Normal file
99
src/python/qubed/protobuf/adapters.py
Normal file
99
src/python/qubed/protobuf/adapters.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from frozendict import frozendict
|
||||||
|
|
||||||
|
from ..value_types import QEnum
|
||||||
|
from . import qube_pb2
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..Qube import Qube
|
||||||
|
|
||||||
|
|
||||||
|
def _ndarray_to_proto(arr: np.ndarray) -> qube_pb2.NdArray:
|
||||||
|
"""np.ndarray → NdArray message"""
|
||||||
|
return qube_pb2.NdArray(
|
||||||
|
shape=list(arr.shape),
|
||||||
|
dtype=str(arr.dtype),
|
||||||
|
raw=arr.tobytes(order="C"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _ndarray_from_proto(msg: qube_pb2.NdArray) -> np.ndarray:
|
||||||
|
"""NdArray message → np.ndarray (immutable view)"""
|
||||||
|
return np.frombuffer(msg.raw, dtype=msg.dtype).reshape(tuple(msg.shape))
|
||||||
|
|
||||||
|
|
||||||
|
def _py_to_valuegroup(value: list[str] | np.ndarray) -> qube_pb2.ValueGroup:
|
||||||
|
"""Accept str-sequence *or* ndarray and return ValueGroup."""
|
||||||
|
vg = qube_pb2.ValueGroup()
|
||||||
|
if isinstance(value, np.ndarray):
|
||||||
|
vg.tensor.CopyFrom(_ndarray_to_proto(value))
|
||||||
|
else:
|
||||||
|
vg.s.items.extend(value)
|
||||||
|
return vg
|
||||||
|
|
||||||
|
|
||||||
|
def _valuegroup_to_py(vg: qube_pb2.ValueGroup) -> list[str] | np.ndarray:
|
||||||
|
"""ValueGroup → list[str] *or* ndarray"""
|
||||||
|
arm = vg.WhichOneof("payload")
|
||||||
|
if arm == "tensor":
|
||||||
|
return _ndarray_from_proto(vg.tensor)
|
||||||
|
|
||||||
|
return QEnum(vg.s.items)
|
||||||
|
|
||||||
|
|
||||||
|
def _py_to_metadatagroup(value: np.ndarray) -> qube_pb2.MetadataGroup:
|
||||||
|
"""Accept str-sequence *or* ndarray and return ValueGroup."""
|
||||||
|
vg = qube_pb2.MetadataGroup()
|
||||||
|
if not isinstance(value, np.ndarray):
|
||||||
|
value = np.array([value])
|
||||||
|
|
||||||
|
vg.tensor.CopyFrom(_ndarray_to_proto(value))
|
||||||
|
return vg
|
||||||
|
|
||||||
|
|
||||||
|
def _metadatagroup_to_py(vg: qube_pb2.MetadataGroup) -> np.ndarray:
|
||||||
|
"""ValueGroup → list[str] *or* ndarray"""
|
||||||
|
arm = vg.WhichOneof("payload")
|
||||||
|
if arm == "tensor":
|
||||||
|
return _ndarray_from_proto(vg.tensor)
|
||||||
|
|
||||||
|
raise ValueError(f"Unknown arm {arm}")
|
||||||
|
|
||||||
|
|
||||||
|
def _qube_to_proto(q: Qube) -> qube_pb2.Qube:
|
||||||
|
"""Frozen Qube dataclass → protobuf Qube message (new object)."""
|
||||||
|
return qube_pb2.Qube(
|
||||||
|
key=q.key,
|
||||||
|
values=_py_to_valuegroup(q.values),
|
||||||
|
metadata={k: _py_to_metadatagroup(v) for k, v in q.metadata.items()},
|
||||||
|
children=[_qube_to_proto(c) for c in q.children],
|
||||||
|
is_root=q.is_root,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def qube_to_proto(q: Qube) -> bytes:
|
||||||
|
return _qube_to_proto(q).SerializeToString()
|
||||||
|
|
||||||
|
|
||||||
|
def _proto_to_qube(cls: type, msg: qube_pb2.Qube) -> Qube:
|
||||||
|
"""protobuf Qube message → frozen Qube dataclass (new object)."""
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
key=msg.key,
|
||||||
|
values=_valuegroup_to_py(msg.values),
|
||||||
|
metadata=frozendict(
|
||||||
|
{k: _metadatagroup_to_py(v) for k, v in msg.metadata.items()}
|
||||||
|
),
|
||||||
|
children=tuple(_proto_to_qube(cls, c) for c in msg.children),
|
||||||
|
is_root=msg.is_root,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def proto_to_qube(cls: type, wire: bytes) -> Qube:
|
||||||
|
msg = qube_pb2.Qube()
|
||||||
|
msg.ParseFromString(wire)
|
||||||
|
return _proto_to_qube(cls, msg)
|
45
src/python/qubed/protobuf/qube_pb2.py
Normal file
45
src/python/qubed/protobuf/qube_pb2.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||||
|
# NO CHECKED-IN PROTOBUF GENCODE
|
||||||
|
# source: qube.proto
|
||||||
|
# Protobuf Python Version: 5.29.0
|
||||||
|
"""Generated protocol buffer code."""
|
||||||
|
|
||||||
|
from google.protobuf import descriptor as _descriptor
|
||||||
|
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||||
|
from google.protobuf import runtime_version as _runtime_version
|
||||||
|
from google.protobuf import symbol_database as _symbol_database
|
||||||
|
from google.protobuf.internal import builder as _builder
|
||||||
|
|
||||||
|
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||||
|
_runtime_version.Domain.PUBLIC, 5, 29, 0, "", "qube.proto"
|
||||||
|
)
|
||||||
|
# @@protoc_insertion_point(imports)
|
||||||
|
|
||||||
|
_sym_db = _symbol_database.Default()
|
||||||
|
|
||||||
|
|
||||||
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
|
||||||
|
b'\n\nqube.proto"4\n\x07NdArray\x12\r\n\x05shape\x18\x01 \x03(\x03\x12\r\n\x05\x64type\x18\x02 \x01(\t\x12\x0b\n\x03raw\x18\x03 \x01(\x0c"\x1c\n\x0bStringGroup\x12\r\n\x05items\x18\x01 \x03(\t"N\n\nValueGroup\x12\x19\n\x01s\x18\x01 \x01(\x0b\x32\x0c.StringGroupH\x00\x12\x1a\n\x06tensor\x18\x02 \x01(\x0b\x32\x08.NdArrayH\x00\x42\t\n\x07payload"6\n\rMetadataGroup\x12\x1a\n\x06tensor\x18\x01 \x01(\x0b\x32\x08.NdArrayH\x00\x42\t\n\x07payload"\xd1\x01\n\x04Qube\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x1b\n\x06values\x18\x02 \x01(\x0b\x32\x0b.ValueGroup\x12%\n\x08metadata\x18\x03 \x03(\x0b\x32\x13.Qube.MetadataEntry\x12\r\n\x05\x64type\x18\x04 \x01(\t\x12\x17\n\x08\x63hildren\x18\x05 \x03(\x0b\x32\x05.Qube\x12\x0f\n\x07is_root\x18\x06 \x01(\x08\x1a?\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x1d\n\x05value\x18\x02 \x01(\x0b\x32\x0e.MetadataGroup:\x02\x38\x01\x62\x06proto3'
|
||||||
|
)
|
||||||
|
|
||||||
|
_globals = globals()
|
||||||
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||||
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "qube_pb2", _globals)
|
||||||
|
if not _descriptor._USE_C_DESCRIPTORS:
|
||||||
|
DESCRIPTOR._loaded_options = None
|
||||||
|
_globals["_QUBE_METADATAENTRY"]._loaded_options = None
|
||||||
|
_globals["_QUBE_METADATAENTRY"]._serialized_options = b"8\001"
|
||||||
|
_globals["_NDARRAY"]._serialized_start = 14
|
||||||
|
_globals["_NDARRAY"]._serialized_end = 66
|
||||||
|
_globals["_STRINGGROUP"]._serialized_start = 68
|
||||||
|
_globals["_STRINGGROUP"]._serialized_end = 96
|
||||||
|
_globals["_VALUEGROUP"]._serialized_start = 98
|
||||||
|
_globals["_VALUEGROUP"]._serialized_end = 176
|
||||||
|
_globals["_METADATAGROUP"]._serialized_start = 178
|
||||||
|
_globals["_METADATAGROUP"]._serialized_end = 232
|
||||||
|
_globals["_QUBE"]._serialized_start = 235
|
||||||
|
_globals["_QUBE"]._serialized_end = 444
|
||||||
|
_globals["_QUBE_METADATAENTRY"]._serialized_start = 381
|
||||||
|
_globals["_QUBE_METADATAENTRY"]._serialized_end = 444
|
||||||
|
# @@protoc_insertion_point(module_scope)
|
@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Iterable
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
|
|
||||||
from .node_types import NodeData
|
|
||||||
from .value_types import QEnum, ValueGroup, WildcardGroup
|
from .value_types import QEnum, ValueGroup, WildcardGroup
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -27,7 +26,7 @@ class SetOperation(Enum):
|
|||||||
@dataclass(eq=True, frozen=True)
|
@dataclass(eq=True, frozen=True)
|
||||||
class ValuesMetadata:
|
class ValuesMetadata:
|
||||||
values: ValueGroup
|
values: ValueGroup
|
||||||
metadata: dict[str, np.ndarray]
|
indices: list[int] | slice
|
||||||
|
|
||||||
|
|
||||||
def QEnum_intersection(
|
def QEnum_intersection(
|
||||||
@ -49,19 +48,17 @@ def QEnum_intersection(
|
|||||||
|
|
||||||
intersection_out = ValuesMetadata(
|
intersection_out = ValuesMetadata(
|
||||||
values=QEnum(list(intersection.keys())),
|
values=QEnum(list(intersection.keys())),
|
||||||
metadata={
|
indices=list(intersection.values()),
|
||||||
k: v[..., tuple(intersection.values())] for k, v in A.metadata.items()
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
just_A_out = ValuesMetadata(
|
just_A_out = ValuesMetadata(
|
||||||
values=QEnum(list(just_A.keys())),
|
values=QEnum(list(just_A.keys())),
|
||||||
metadata={k: v[..., tuple(just_A.values())] for k, v in A.metadata.items()},
|
indices=list(just_A.values()),
|
||||||
)
|
)
|
||||||
|
|
||||||
just_B_out = ValuesMetadata(
|
just_B_out = ValuesMetadata(
|
||||||
values=QEnum(list(just_B.keys())),
|
values=QEnum(list(just_B.keys())),
|
||||||
metadata={k: v[..., tuple(just_B.values())] for k, v in B.metadata.items()},
|
indices=list(just_B.values()),
|
||||||
)
|
)
|
||||||
|
|
||||||
return just_A_out, intersection_out, just_B_out
|
return just_A_out, intersection_out, just_B_out
|
||||||
@ -76,61 +73,107 @@ def node_intersection(
|
|||||||
|
|
||||||
if isinstance(A.values, WildcardGroup) and isinstance(B.values, WildcardGroup):
|
if isinstance(A.values, WildcardGroup) and isinstance(B.values, WildcardGroup):
|
||||||
return (
|
return (
|
||||||
ValuesMetadata(QEnum([]), {}),
|
ValuesMetadata(QEnum([]), []),
|
||||||
ValuesMetadata(WildcardGroup(), {}),
|
ValuesMetadata(WildcardGroup(), slice(None)),
|
||||||
ValuesMetadata(QEnum([]), {}),
|
ValuesMetadata(QEnum([]), []),
|
||||||
)
|
)
|
||||||
|
|
||||||
# If A is a wildcard matcher then the intersection is everything
|
# If A is a wildcard matcher then the intersection is everything
|
||||||
# just_A is still *
|
# just_A is still *
|
||||||
# just_B is empty
|
# just_B is empty
|
||||||
if isinstance(A.values, WildcardGroup):
|
if isinstance(A.values, WildcardGroup):
|
||||||
return A, B, ValuesMetadata(QEnum([]), {})
|
return A, B, ValuesMetadata(QEnum([]), [])
|
||||||
|
|
||||||
# The reverse if B is a wildcard
|
# The reverse if B is a wildcard
|
||||||
if isinstance(B.values, WildcardGroup):
|
if isinstance(B.values, WildcardGroup):
|
||||||
return ValuesMetadata(QEnum([]), {}), A, B
|
return ValuesMetadata(QEnum([]), []), A, B
|
||||||
|
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Fused set operations on values types {type(A.values)} and {type(B.values)} not yet implemented"
|
f"Fused set operations on values types {type(A.values)} and {type(B.values)} not yet implemented"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def operation(A: Qube, B: Qube, operation_type: SetOperation, node_type) -> Qube | None:
|
def operation(
|
||||||
|
A: Qube, B: Qube, operation_type: SetOperation, node_type, depth=0
|
||||||
|
) -> Qube | None:
|
||||||
assert A.key == B.key, (
|
assert A.key == B.key, (
|
||||||
"The two Qube root nodes must have the same key to perform set operations,"
|
"The two Qube root nodes must have the same key to perform set operations,"
|
||||||
f"would usually be two root nodes. They have {A.key} and {B.key} respectively"
|
f"would usually be two root nodes. They have {A.key} and {B.key} respectively"
|
||||||
)
|
)
|
||||||
|
node_key = A.key
|
||||||
|
|
||||||
|
assert A.is_root == B.is_root
|
||||||
|
is_root = A.is_root
|
||||||
|
|
||||||
assert A.values == B.values, (
|
assert A.values == B.values, (
|
||||||
f"The two Qube root nodes must have the same values to perform set operations {A.values = }, {B.values = }"
|
f"The two Qube root nodes must have the same values to perform set operations {A.values = }, {B.values = }"
|
||||||
)
|
)
|
||||||
|
node_values = A.values
|
||||||
|
|
||||||
# Group the children of the two nodes by key
|
# Group the children of the two nodes by key
|
||||||
nodes_by_key: defaultdict[str, tuple[list[Qube], list[Qube]]] = defaultdict(
|
nodes_by_key: defaultdict[str, tuple[list[Qube], list[Qube]]] = defaultdict(
|
||||||
lambda: ([], [])
|
lambda: ([], [])
|
||||||
)
|
)
|
||||||
for node in A.children:
|
|
||||||
nodes_by_key[node.key][0].append(node)
|
|
||||||
for node in B.children:
|
|
||||||
nodes_by_key[node.key][1].append(node)
|
|
||||||
|
|
||||||
new_children: list[Qube] = []
|
new_children: list[Qube] = []
|
||||||
|
|
||||||
|
# Sort out metadata into what can stay at this level and what must move down
|
||||||
|
stayput_metadata: dict[str, np.ndarray] = {}
|
||||||
|
pushdown_metadata_A: dict[str, np.ndarray] = {}
|
||||||
|
pushdown_metadata_B: dict[str, np.ndarray] = {}
|
||||||
|
for key in set(A.metadata.keys()) | set(B.metadata.keys()):
|
||||||
|
if key not in A.metadata:
|
||||||
|
raise ValueError(f"B has key {key} but A does not. {A = } {B = }")
|
||||||
|
if key not in B.metadata:
|
||||||
|
raise ValueError(f"A has key {key} but B does not. {A = } {B = }")
|
||||||
|
|
||||||
|
print(f"{key = } {A.metadata[key] = } {B.metadata[key]}")
|
||||||
|
A_val = A.metadata[key]
|
||||||
|
B_val = B.metadata[key]
|
||||||
|
if A_val == B_val:
|
||||||
|
print(f"{' ' * depth}Keeping metadata key '{key}' at this level")
|
||||||
|
stayput_metadata[key] = A.metadata[key]
|
||||||
|
else:
|
||||||
|
print(f"{' ' * depth}Pushing down metadata key '{key}' {A_val} {B_val}")
|
||||||
|
pushdown_metadata_A[key] = A_val
|
||||||
|
pushdown_metadata_B[key] = B_val
|
||||||
|
|
||||||
|
# Add all the metadata that needs to be pushed down to the child nodes
|
||||||
|
# When pushing down the metadata we need to account for the fact it now affects more values
|
||||||
|
# So expand the metadata entries from shape (a, b, ..., c) to (a, b, ..., c, d)
|
||||||
|
# where d is the length of the node values
|
||||||
|
for node in A.children:
|
||||||
|
N = len(node.values)
|
||||||
|
print(N)
|
||||||
|
meta = {
|
||||||
|
k: np.broadcast_to(v[..., np.newaxis], v.shape + (N,))
|
||||||
|
for k, v in pushdown_metadata_A.items()
|
||||||
|
}
|
||||||
|
node = node.replace(metadata=node.metadata | meta)
|
||||||
|
nodes_by_key[node.key][0].append(node)
|
||||||
|
|
||||||
|
for node in B.children:
|
||||||
|
N = len(node.values)
|
||||||
|
meta = {
|
||||||
|
k: np.broadcast_to(v[..., np.newaxis], v.shape + (N,))
|
||||||
|
for k, v in pushdown_metadata_B.items()
|
||||||
|
}
|
||||||
|
node = node.replace(metadata=node.metadata | meta)
|
||||||
|
nodes_by_key[node.key][1].append(node)
|
||||||
|
|
||||||
# For every node group, perform the set operation
|
# For every node group, perform the set operation
|
||||||
for key, (A_nodes, B_nodes) in nodes_by_key.items():
|
for key, (A_nodes, B_nodes) in nodes_by_key.items():
|
||||||
output = list(_operation(key, A_nodes, B_nodes, operation_type, node_type))
|
output = list(
|
||||||
|
_operation(key, A_nodes, B_nodes, operation_type, node_type, depth + 1)
|
||||||
|
)
|
||||||
|
# print(f"{' '*depth}_operation {operation_type.name} {A_nodes} {B_nodes} out = [{output}]")
|
||||||
new_children.extend(output)
|
new_children.extend(output)
|
||||||
|
|
||||||
# print(f"operation {operation_type}: {A}, {B} {new_children = }")
|
# print(f"{' '*depth}operation {operation_type.name} [{A}] [{B}] new_children = [{new_children}]")
|
||||||
# print(f"{A.children = }")
|
|
||||||
# print(f"{B.children = }")
|
|
||||||
# print(f"{new_children = }")
|
|
||||||
|
|
||||||
# If there are now no children as a result of the operation, return nothing.
|
# If there are now no children as a result of the operation, return nothing.
|
||||||
if (A.children or B.children) and not new_children:
|
if (A.children or B.children) and not new_children:
|
||||||
if A.key == "root":
|
if A.key == "root":
|
||||||
return A.replace(children=())
|
return node_type.make_root(children=())
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -140,20 +183,34 @@ def operation(A: Qube, B: Qube, operation_type: SetOperation, node_type) -> Qube
|
|||||||
new_children = list(compress_children(new_children))
|
new_children = list(compress_children(new_children))
|
||||||
|
|
||||||
# The values and key are the same so we just replace the children
|
# The values and key are the same so we just replace the children
|
||||||
return A.replace(children=tuple(sorted(new_children)))
|
return node_type.make_node(
|
||||||
|
key=node_key,
|
||||||
|
values=node_values,
|
||||||
|
children=new_children,
|
||||||
|
metadata=stayput_metadata,
|
||||||
|
is_root=is_root,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_indices(metadata: dict[str, np.ndarray], indices: list[int] | slice):
|
||||||
|
return {k: v[..., indices] for k, v in metadata.items()}
|
||||||
|
|
||||||
|
|
||||||
# The root node is special so we need a helper method that we can recurse on
|
|
||||||
def _operation(
|
def _operation(
|
||||||
key: str, A: list[Qube], B: list[Qube], operation_type: SetOperation, node_type
|
key: str,
|
||||||
|
A: list[Qube],
|
||||||
|
B: list[Qube],
|
||||||
|
operation_type: SetOperation,
|
||||||
|
node_type,
|
||||||
|
depth: int,
|
||||||
) -> Iterable[Qube]:
|
) -> Iterable[Qube]:
|
||||||
keep_just_A, keep_intersection, keep_just_B = operation_type.value
|
keep_just_A, keep_intersection, keep_just_B = operation_type.value
|
||||||
|
|
||||||
# Iterate over all pairs (node_A, node_B)
|
|
||||||
values = {}
|
values = {}
|
||||||
for node in A + B:
|
for node in A + B:
|
||||||
values[node] = ValuesMetadata(node.values, node.metadata)
|
values[node] = ValuesMetadata(node.values, node.metadata)
|
||||||
|
|
||||||
|
# Iterate over all pairs (node_A, node_B)
|
||||||
for node_a in A:
|
for node_a in A:
|
||||||
for node_b in B:
|
for node_b in B:
|
||||||
# Compute A - B, A & B, B - A
|
# Compute A - B, A & B, B - A
|
||||||
@ -171,17 +228,21 @@ def _operation(
|
|||||||
if intersection.values:
|
if intersection.values:
|
||||||
new_node_a = node_a.replace(
|
new_node_a = node_a.replace(
|
||||||
values=intersection.values,
|
values=intersection.values,
|
||||||
metadata=intersection.metadata,
|
metadata=get_indices(node_a.metadata, intersection.indices),
|
||||||
)
|
)
|
||||||
new_node_b = node_b.replace(
|
new_node_b = node_b.replace(
|
||||||
values=intersection.values,
|
values=intersection.values,
|
||||||
metadata=intersection.metadata,
|
metadata=get_indices(node_b.metadata, intersection.indices),
|
||||||
)
|
)
|
||||||
# print(f"{node_a = }")
|
# print(f"{' '*depth}{node_a = }")
|
||||||
# print(f"{node_b = }")
|
# print(f"{' '*depth}{node_b = }")
|
||||||
# print(f"{intersection.values =}")
|
# print(f"{' '*depth}{intersection.values =}")
|
||||||
result = operation(
|
result = operation(
|
||||||
new_node_a, new_node_b, operation_type, node_type
|
new_node_a,
|
||||||
|
new_node_b,
|
||||||
|
operation_type,
|
||||||
|
node_type,
|
||||||
|
depth=depth + 1,
|
||||||
)
|
)
|
||||||
if result is not None:
|
if result is not None:
|
||||||
yield result
|
yield result
|
||||||
@ -190,20 +251,20 @@ def _operation(
|
|||||||
if keep_just_A:
|
if keep_just_A:
|
||||||
for node in A:
|
for node in A:
|
||||||
if values[node].values:
|
if values[node].values:
|
||||||
yield node_type.make(
|
yield node_type.make_node(
|
||||||
key,
|
key,
|
||||||
children=node.children,
|
children=node.children,
|
||||||
values=values[node].values,
|
values=values[node].values,
|
||||||
metadata=values[node].metadata,
|
metadata=get_indices(node.metadata, values[node].indices),
|
||||||
)
|
)
|
||||||
if keep_just_B:
|
if keep_just_B:
|
||||||
for node in B:
|
for node in B:
|
||||||
if values[node].values:
|
if values[node].values:
|
||||||
yield node_type.make(
|
yield node_type.make_node(
|
||||||
key,
|
key,
|
||||||
children=node.children,
|
children=node.children,
|
||||||
values=values[node].values,
|
values=values[node].values,
|
||||||
metadata=values[node].metadata,
|
metadata=get_indices(node.metadata, values[node].indices),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -230,7 +291,7 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
|
|||||||
key = child_list[0].key
|
key = child_list[0].key
|
||||||
|
|
||||||
# Compress the children into a single node
|
# Compress the children into a single node
|
||||||
assert all(isinstance(child.data.values, QEnum) for child in child_list), (
|
assert all(isinstance(child.values, QEnum) for child in child_list), (
|
||||||
"All children must have QEnum values"
|
"All children must have QEnum values"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -241,19 +302,19 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
|
|||||||
|
|
||||||
metadata: frozendict[str, np.ndarray] = frozendict(
|
metadata: frozendict[str, np.ndarray] = frozendict(
|
||||||
{
|
{
|
||||||
k: np.concatenate(metadata_group, axis=0)
|
k: np.concatenate(metadata_group, axis=-1)
|
||||||
for k, metadata_group in metadata_groups.items()
|
for k, metadata_group in metadata_groups.items()
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
node_data = NodeData(
|
|
||||||
key=key,
|
|
||||||
metadata=metadata,
|
|
||||||
values=QEnum(set(v for child in child_list for v in child.data.values)),
|
|
||||||
)
|
|
||||||
children = [cc for c in child_list for cc in c.children]
|
children = [cc for c in child_list for cc in c.children]
|
||||||
compressed_children = compress_children(children)
|
compressed_children = compress_children(children)
|
||||||
new_child = node_type(data=node_data, children=compressed_children)
|
new_child = node_type.make_node(
|
||||||
|
key=key,
|
||||||
|
metadata=metadata,
|
||||||
|
values=QEnum(set(v for child in child_list for v in child.values)),
|
||||||
|
children=compressed_children,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# If the group is size one just keep it
|
# If the group is size one just keep it
|
||||||
new_child = child_list.pop()
|
new_child = child_list.pop()
|
||||||
|
@ -2,9 +2,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Iterable
|
from typing import TYPE_CHECKING, Callable, Iterable
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .Qube import Qube
|
from .Qube import Qube
|
||||||
@ -71,27 +70,38 @@ def node_tree_to_string(node: Qube, prefix: str = "", depth=None) -> Iterable[st
|
|||||||
|
|
||||||
|
|
||||||
def summarize_node_html(
|
def summarize_node_html(
|
||||||
node: Qube, collapse=False, max_summary_length=50, **kwargs
|
node: Qube,
|
||||||
|
collapse=False,
|
||||||
|
max_summary_length=50,
|
||||||
|
info: Callable[[Qube], str] | None = None,
|
||||||
|
**kwargs,
|
||||||
) -> tuple[str, Qube]:
|
) -> tuple[str, Qube]:
|
||||||
"""
|
"""
|
||||||
Extracts a summarized representation of the node while collapsing single-child paths.
|
Extracts a summarized representation of the node while collapsing single-child paths.
|
||||||
Returns the summary string and the last node in the chain that has multiple children.
|
Returns the summary string and the last node in the chain that has multiple children.
|
||||||
"""
|
"""
|
||||||
|
if info is None:
|
||||||
|
|
||||||
|
def info_func(node: Qube, /):
|
||||||
|
return (
|
||||||
|
# f"dtype: {node.dtype}\n"
|
||||||
|
f"metadata: {dict(node.metadata)}\n"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
info_func = info
|
||||||
|
|
||||||
summaries = []
|
summaries = []
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
path = node.summary(**kwargs)
|
path = node.summary(**kwargs)
|
||||||
summary = path
|
summary = path
|
||||||
if "is_leaf" in node.metadata and node.metadata["is_leaf"]:
|
|
||||||
summary += " 🌿"
|
|
||||||
|
|
||||||
if len(summary) > max_summary_length:
|
if len(summary) > max_summary_length:
|
||||||
summary = summary[:max_summary_length] + "..."
|
summary = summary[:max_summary_length] + "..."
|
||||||
info = (
|
|
||||||
f"dtype: {node.dtype.__name__}\n"
|
info_string = info_func(node)
|
||||||
f"metadata: {dict((k, np.shape(v)) for k, v in node.metadata.items())}\n"
|
|
||||||
)
|
summary = f'<span class="qubed-node" data-path="{path}" title="{info_string}">{summary}</span>'
|
||||||
summary = f'<span class="qubed-node" data-path="{path}" title="{info}">{summary}</span>'
|
|
||||||
summaries.append(summary)
|
summaries.append(summary)
|
||||||
if not collapse:
|
if not collapse:
|
||||||
break
|
break
|
||||||
@ -105,9 +115,14 @@ def summarize_node_html(
|
|||||||
|
|
||||||
|
|
||||||
def _node_tree_to_html(
|
def _node_tree_to_html(
|
||||||
node: Qube, prefix: str = "", depth=1, connector="", **kwargs
|
node: Qube,
|
||||||
|
prefix: str = "",
|
||||||
|
depth=1,
|
||||||
|
connector="",
|
||||||
|
info: Callable[[Qube], str] | None = None,
|
||||||
|
**kwargs,
|
||||||
) -> Iterable[str]:
|
) -> Iterable[str]:
|
||||||
summary, node = summarize_node_html(node, **kwargs)
|
summary, node = summarize_node_html(node, info=info, **kwargs)
|
||||||
|
|
||||||
if len(node.children) == 0:
|
if len(node.children) == 0:
|
||||||
yield f'<span class="qubed-level">{connector}{summary}</span>'
|
yield f'<span class="qubed-level">{connector}{summary}</span>'
|
||||||
@ -124,13 +139,20 @@ def _node_tree_to_html(
|
|||||||
prefix + extension,
|
prefix + extension,
|
||||||
depth=depth - 1,
|
depth=depth - 1,
|
||||||
connector=prefix + connector,
|
connector=prefix + connector,
|
||||||
|
info=info,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
yield "</details>"
|
yield "</details>"
|
||||||
|
|
||||||
|
|
||||||
def node_tree_to_html(
|
def node_tree_to_html(
|
||||||
node: Qube, depth=1, include_css=True, include_js=True, css_id=None, **kwargs
|
node: Qube,
|
||||||
|
depth=1,
|
||||||
|
include_css=True,
|
||||||
|
include_js=True,
|
||||||
|
css_id=None,
|
||||||
|
info: Callable[[Qube], str] | None = None,
|
||||||
|
**kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
if css_id is None:
|
if css_id is None:
|
||||||
css_id = f"qubed-tree-{random.randint(0, 1000000)}"
|
css_id = f"qubed-tree-{random.randint(0, 1000000)}"
|
||||||
@ -215,5 +237,5 @@ def node_tree_to_html(
|
|||||||
nodes.forEach(n => n.addEventListener("click", nodeOnClick));
|
nodes.forEach(n => n.addEventListener("click", nodeOnClick));
|
||||||
</script>
|
</script>
|
||||||
""".replace("CSS_ID", css_id)
|
""".replace("CSS_ID", css_id)
|
||||||
nodes = "".join(_node_tree_to_html(node=node, depth=depth, **kwargs))
|
nodes = "".join(_node_tree_to_html(node=node, depth=depth, info=info, **kwargs))
|
||||||
return f"{js if include_js else ''}{css if include_css else ''}<pre class='qubed-tree' id='{css_id}'>{nodes}</pre>"
|
return f"{js if include_js else ''}{css if include_css else ''}<pre class='qubed-tree' id='{css_id}'>{nodes}</pre>"
|
||||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, replace
|
from dataclasses import dataclass
|
||||||
from datetime import date, datetime, timedelta
|
from datetime import date, datetime, timedelta
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
@ -21,6 +21,11 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ValueGroup(ABC):
|
class ValueGroup(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def dtype(self) -> str:
|
||||||
|
"Provide a string rep of the datatype of these values"
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def summary(self) -> str:
|
def summary(self) -> str:
|
||||||
"Provide a string summary of the value group."
|
"Provide a string summary of the value group."
|
||||||
@ -69,9 +74,13 @@ class QEnum(ValueGroup):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
values: EnumValuesType
|
values: EnumValuesType
|
||||||
|
_dtype: str = "str"
|
||||||
|
|
||||||
def __init__(self, obj):
|
def __init__(self, obj):
|
||||||
object.__setattr__(self, "values", tuple(sorted(obj)))
|
object.__setattr__(self, "values", tuple(sorted(obj)))
|
||||||
|
object.__setattr__(
|
||||||
|
self, "dtype", type(self.values[0]) if len(self.values) > 0 else "str"
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
assert isinstance(self.values, tuple)
|
assert isinstance(self.values, tuple)
|
||||||
@ -88,6 +97,9 @@ class QEnum(ValueGroup):
|
|||||||
def __contains__(self, value: Any) -> bool:
|
def __contains__(self, value: Any) -> bool:
|
||||||
return value in self.values
|
return value in self.values
|
||||||
|
|
||||||
|
def dtype(self):
|
||||||
|
return self._dtype
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_strings(cls, values: Iterable[str]) -> Sequence[ValueGroup]:
|
def from_strings(cls, values: Iterable[str]) -> Sequence[ValueGroup]:
|
||||||
return [cls(tuple(values))]
|
return [cls(tuple(values))]
|
||||||
@ -114,7 +126,7 @@ class WildcardGroup(ValueGroup):
|
|||||||
return "*"
|
return "*"
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return None
|
return 1
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return ["*"]
|
return ["*"]
|
||||||
@ -122,6 +134,9 @@ class WildcardGroup(ValueGroup):
|
|||||||
def __bool__(self):
|
def __bool__(self):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def dtype(self):
|
||||||
|
return "*"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_strings(cls, values: Iterable[str]) -> Sequence[ValueGroup]:
|
def from_strings(cls, values: Iterable[str]) -> Sequence[ValueGroup]:
|
||||||
return [WildcardGroup()]
|
return [WildcardGroup()]
|
||||||
@ -398,7 +413,7 @@ def convert_datatypes(q: "Qube", conversions: dict[str, ValueGroup]) -> "Qube":
|
|||||||
)
|
)
|
||||||
for values_group in data_type.from_strings(q.values):
|
for values_group in data_type.from_strings(q.values):
|
||||||
# print(values_group)
|
# print(values_group)
|
||||||
yield replace(q, data=replace(q.data, values=values_group))
|
yield q.replace(values=values_group)
|
||||||
else:
|
else:
|
||||||
yield q
|
yield q
|
||||||
|
|
||||||
|
32
src/qube.proto
Normal file
32
src/qube.proto
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
message NdArray {
|
||||||
|
repeated int64 shape = 1;
|
||||||
|
string dtype = 2;
|
||||||
|
bytes raw = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message StringGroup {repeated string items = 1; }
|
||||||
|
|
||||||
|
// Stores values i.e class=1/2/3 the 1/2/3 part
|
||||||
|
message ValueGroup {
|
||||||
|
oneof payload {
|
||||||
|
StringGroup s = 1;
|
||||||
|
NdArray tensor = 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
message MetadataGroup {
|
||||||
|
oneof payload {
|
||||||
|
NdArray tensor = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
message Qube {
|
||||||
|
string key = 1;
|
||||||
|
ValueGroup values = 2;
|
||||||
|
map<string, MetadataGroup> metadata = 3;
|
||||||
|
string dtype = 4;
|
||||||
|
repeated Qube children = 5;
|
||||||
|
bool is_root = 6;
|
||||||
|
}
|
@ -176,7 +176,7 @@ def follow_query(request: dict[str, str | list[str]], qube: Qube):
|
|||||||
by_path = defaultdict(lambda: {"paths": set(), "values": set()})
|
by_path = defaultdict(lambda: {"paths": set(), "values": set()})
|
||||||
|
|
||||||
for request, node in s.leaf_nodes():
|
for request, node in s.leaf_nodes():
|
||||||
if not node.data.metadata["is_leaf"]:
|
if not node.metadata["is_leaf"]:
|
||||||
by_path[node.key]["values"].update(node.values.values)
|
by_path[node.key]["values"].update(node.values.values)
|
||||||
by_path[node.key]["paths"].add(frozendict(request))
|
by_path[node.key]["paths"].add(frozendict(request))
|
||||||
|
|
||||||
|
@ -1,36 +1,40 @@
|
|||||||
from qubed import Qube
|
from qubed import Qube
|
||||||
|
|
||||||
d = {
|
q = Qube.from_tree("""
|
||||||
"class=od": {
|
root
|
||||||
"expver=0001": {"param=1": {}, "param=2": {}},
|
├── class=od
|
||||||
"expver=0002": {"param=1": {}, "param=2": {}},
|
│ ├── expver=0001
|
||||||
},
|
│ │ ├── param=1
|
||||||
"class=rd": {
|
│ │ └── param=2
|
||||||
"expver=0001": {"param=1": {}, "param=2": {}, "param=3": {}},
|
│ └── expver=0002
|
||||||
"expver=0002": {"param=1": {}, "param=2": {}},
|
│ ├── param=1
|
||||||
},
|
│ └── param=2
|
||||||
}
|
└── class=rd
|
||||||
q = Qube.from_dict(d)
|
├── expver=0001
|
||||||
|
│ ├── param=1
|
||||||
|
│ ├── param=2
|
||||||
def test_eq():
|
│ └── param=3
|
||||||
r = Qube.from_dict(d)
|
└── expver=0002
|
||||||
assert q == r
|
├── param=1
|
||||||
|
└── param=2
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
def test_getitem():
|
def test_getitem():
|
||||||
assert q["class", "od"] == Qube.from_dict(
|
assert q["class", "od"] == Qube.from_tree("""
|
||||||
{
|
root
|
||||||
"expver=0001": {"param=1": {}, "param=2": {}},
|
├── expver=0001
|
||||||
"expver=0002": {"param=1": {}, "param=2": {}},
|
│ ├── param=1
|
||||||
}
|
│ └── param=2
|
||||||
)
|
└── expver=0002
|
||||||
assert q["class", "od"]["expver", "0001"] == Qube.from_dict(
|
├── param=1
|
||||||
{
|
└── param=2
|
||||||
"param=1": {},
|
""")
|
||||||
"param=2": {},
|
|
||||||
}
|
assert q["class", "od"]["expver", "0001"] == Qube.from_tree("""
|
||||||
)
|
root
|
||||||
|
├── param=1
|
||||||
|
└── param=2""")
|
||||||
|
|
||||||
|
|
||||||
def test_n_leaves():
|
def test_n_leaves():
|
||||||
|
@ -2,7 +2,7 @@ from qubed import Qube
|
|||||||
|
|
||||||
|
|
||||||
def test_json_round_trip():
|
def test_json_round_trip():
|
||||||
u = Qube.from_dict(
|
from_dict = Qube.from_dict(
|
||||||
{
|
{
|
||||||
"class=d1": {
|
"class=d1": {
|
||||||
"dataset=climate-dt/weather-dt": {
|
"dataset=climate-dt/weather-dt": {
|
||||||
@ -14,5 +14,54 @@ def test_json_round_trip():
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
json = u.to_json()
|
|
||||||
assert Qube.from_json(json) == u
|
from_tree = Qube.from_tree("""
|
||||||
|
root, class=d1
|
||||||
|
├── dataset=another-value, generation=1/2/3
|
||||||
|
└── dataset=climate-dt/weather-dt, generation=1/2/3/4
|
||||||
|
""")
|
||||||
|
|
||||||
|
from_json = Qube.from_json(
|
||||||
|
{
|
||||||
|
"key": "root",
|
||||||
|
"values": ["root"],
|
||||||
|
"metadata": {},
|
||||||
|
"children": [
|
||||||
|
{
|
||||||
|
"key": "class",
|
||||||
|
"values": ["d1"],
|
||||||
|
"metadata": {},
|
||||||
|
"children": [
|
||||||
|
{
|
||||||
|
"key": "dataset",
|
||||||
|
"values": ["another-value"],
|
||||||
|
"metadata": {},
|
||||||
|
"children": [
|
||||||
|
{
|
||||||
|
"key": "generation",
|
||||||
|
"values": ["1", "2", "3"],
|
||||||
|
"metadata": {},
|
||||||
|
"children": [],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"key": "dataset",
|
||||||
|
"values": ["climate-dt", "weather-dt"],
|
||||||
|
"metadata": {},
|
||||||
|
"children": [
|
||||||
|
{
|
||||||
|
"key": "generation",
|
||||||
|
"values": ["1", "2", "3", "4"],
|
||||||
|
"metadata": {},
|
||||||
|
"children": [],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert from_tree == from_json
|
||||||
|
assert from_tree == from_dict
|
||||||
|
@ -20,14 +20,6 @@ root
|
|||||||
└── expver=0002, param=1/2
|
└── expver=0002, param=1/2
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
as_html = """
|
|
||||||
<details open><summary class="qubed-level"><span class="qubed-node" data-path="root" title="dtype: str\nmetadata: {}\n">root</span></summary><span class="qubed-level">├── <span class="qubed-node" data-path="class=od" title="dtype: str\nmetadata: {}\n">class=od</span>, <span class="qubed-node" data-path="expver=0001/0002" title="dtype: str\nmetadata: {}\n">expver=0001/0002</span>, <span class="qubed-node" data-path="param=1/2" title="dtype: str\nmetadata: {}\n">param=1/2</span></span><details open><summary class="qubed-level">└── <span class="qubed-node" data-path="class=rd" title="dtype: str\nmetadata: {}\n">class=rd</span></summary><span class="qubed-level"> ├── <span class="qubed-node" data-path="expver=0001" title="dtype: str\nmetadata: {}\n">expver=0001</span>, <span class="qubed-node" data-path="param=1/2/3" title="dtype: str\nmetadata: {}\n">param=1/2/3</span></span><span class="qubed-level"> └── <span class="qubed-node" data-path="expver=0002" title="dtype: str\nmetadata: {}\n">expver=0002</span>, <span class="qubed-node" data-path="param=1/2" title="dtype: str\nmetadata: {}\n">param=1/2</span></span></details></details>
|
|
||||||
""".strip()
|
|
||||||
|
|
||||||
|
|
||||||
def test_string():
|
def test_string():
|
||||||
assert str(q).strip() == as_string
|
assert str(q).strip() == as_string
|
||||||
|
|
||||||
|
|
||||||
def test_html():
|
|
||||||
assert as_html in q._repr_html_()
|
|
||||||
|
@ -1,45 +1,44 @@
|
|||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
from qubed import Qube
|
|
||||||
|
|
||||||
|
|
||||||
def make_set(entries):
|
def make_set(entries):
|
||||||
return set((frozendict(a), frozendict(b)) for a, b in entries)
|
return set((frozendict(a), frozendict(b)) for a, b in entries)
|
||||||
|
|
||||||
|
|
||||||
def test_simple_union():
|
# def test_simple_union():
|
||||||
q = Qube.from_nodes(
|
# q = Qube.from_nodes(
|
||||||
{
|
# {
|
||||||
"class": dict(values=["od", "rd"]),
|
# "class": dict(values=["od", "rd"]),
|
||||||
"expver": dict(values=[1, 2]),
|
# "expver": dict(values=[1, 2]),
|
||||||
"stream": dict(
|
# "stream": dict(
|
||||||
values=["a", "b", "c"], metadata=dict(number=list(range(12)))
|
# values=["a", "b", "c"], metadata=dict(number=list(range(12)))
|
||||||
),
|
# ),
|
||||||
}
|
# }
|
||||||
)
|
# )
|
||||||
|
|
||||||
r = Qube.from_nodes(
|
# r = Qube.from_nodes(
|
||||||
{
|
# {
|
||||||
"class": dict(values=["xd"]),
|
# "class": dict(values=["xd"]),
|
||||||
"expver": dict(values=[1, 2]),
|
# "expver": dict(values=[1, 2]),
|
||||||
"stream": dict(
|
# "stream": dict(
|
||||||
values=["a", "b", "c"], metadata=dict(number=list(range(12, 18)))
|
# values=["a", "b", "c"], metadata=dict(number=list(range(12, 18)))
|
||||||
),
|
# ),
|
||||||
}
|
# }
|
||||||
)
|
# )
|
||||||
|
|
||||||
expected_union = Qube.from_nodes(
|
# expected_union = Qube.from_nodes(
|
||||||
{
|
# {
|
||||||
"class": dict(values=["od", "rd", "xd"]),
|
# "class": dict(values=["od", "rd", "xd"]),
|
||||||
"expver": dict(values=[1, 2]),
|
# "expver": dict(values=[1, 2]),
|
||||||
"stream": dict(
|
# "stream": dict(
|
||||||
values=["a", "b", "c"], metadata=dict(number=list(range(18)))
|
# values=["a", "b", "c"], metadata=dict(number=list(range(18)))
|
||||||
),
|
# ),
|
||||||
}
|
# }
|
||||||
)
|
# )
|
||||||
|
|
||||||
union = q | r
|
# union = q | r
|
||||||
|
|
||||||
assert union == expected_union
|
# assert union == expected_union
|
||||||
assert make_set(expected_union.leaves_with_metadata()) == make_set(
|
# assert make_set(expected_union.leaves_with_metadata()) == make_set(
|
||||||
union.leaves_with_metadata()
|
# union.leaves_with_metadata()
|
||||||
)
|
# )
|
||||||
|
12
tests/test_protobuf.py
Normal file
12
tests/test_protobuf.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
from qubed import Qube
|
||||||
|
|
||||||
|
|
||||||
|
def test_protobuf_simple():
|
||||||
|
q = Qube.from_tree("""
|
||||||
|
root, class=d1
|
||||||
|
├── dataset=another-value, generation=1/2/3
|
||||||
|
└── dataset=climate-dt/weather-dt, generation=1/2/3/4
|
||||||
|
""")
|
||||||
|
wire = q.to_protobuf()
|
||||||
|
round_trip = Qube.from_protobuf(wire)
|
||||||
|
assert round_trip == q
|
Loading…
x
Reference in New Issue
Block a user