diff --git a/src/python/qubed/Qube.py b/src/python/qubed/Qube.py index b9fb503..79a004e 100644 --- a/src/python/qubed/Qube.py +++ b/src/python/qubed/Qube.py @@ -8,17 +8,17 @@ import functools import json from collections import defaultdict from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import cached_property from pathlib import Path -from typing import Any, Iterable, Iterator, Literal, Sequence +from typing import Any, Iterable, Iterator, Literal, Self, Sequence import numpy as np from frozendict import frozendict from . import set_operations from .metadata import from_nodes -from .node_types import NodeData, RootNodeData +from .protobuf.adapters import proto_to_qube, qube_to_proto from .tree_formatters import ( HTML, node_tree_to_html, @@ -32,58 +32,90 @@ from .value_types import ( ) -@dataclass(frozen=False, eq=True, order=True, unsafe_hash=True) -class Qube: - data: NodeData - children: tuple[Qube, ...] +@dataclass +class AxisInfo: + key: str + type: Any + depths: set[int] + values: set - @property - def key(self) -> str: - return self.data.key + def combine(self, other: Self): + self.key = other.key + self.type = other.type + self.depths.update(other.depths) + self.values.update(other.values) + # print(f"combining {self} and {other} getting {result}") - @property - def values(self) -> ValueGroup: - return self.data.values - - @property - def metadata(self): - return self.data.metadata - - @property - def dtype(self): - return self.data.dtype - - def replace(self, **kwargs) -> Qube: - data_keys = { - k: v - for k, v in kwargs.items() - if k in ["key", "values", "metadata", "dtype"] + def to_json(self): + return { + "key": self.key, + "type": self.type.__name__, + "values": list(self.values), + "depths": list(self.depths), } - node_keys = {k: v for k, v in kwargs.items() if k == "children"} - if not data_keys and not node_keys: - return self - if not data_keys: - return dataclasses.replace(self, **node_keys) - return dataclasses.replace( - self, data=dataclasses.replace(self.data, **data_keys), **node_keys - ) + +@dataclass(frozen=True, eq=True, order=True, unsafe_hash=True) +class QubeNamedRoot: + "Helper class to print a custom root name" + + key: str + dtype: str = "str" + children: tuple[Qube, ...] = () def summary(self) -> str: - return self.data.summary() + return self.key + + +@dataclass(frozen=True, eq=True, order=True, unsafe_hash=True) +class Qube: + key: str + values: ValueGroup + metadata: frozendict[str, np.ndarray] = field( + default_factory=lambda: frozendict({}), compare=False + ) + children: tuple[Qube, ...] = () + is_root: bool = False + + def replace(self, **kwargs) -> Qube: + return dataclasses.replace(self, **kwargs) + + def summary(self) -> str: + if self.is_root: + return self.key + return f"{self.key}={self.values.summary()}" if self.key != "root" else "root" @classmethod - def make(cls, key: str, values: ValueGroup, children, **kwargs) -> Qube: + def make_node( + cls, + key: str, + values: Iterable | QEnum | WildcardGroup, + children: Iterable[Qube], + metadata: dict[str, np.ndarray] = {}, + is_root: bool = False, + ) -> Qube: + if isinstance(values, ValueGroup): + values = values + else: + values = QEnum(values) + return cls( - data=NodeData( - key, values, metadata=frozendict(kwargs.get("metadata", frozendict())) - ), + key, + values=values, children=tuple(sorted(children, key=lambda n: ((n.key, n.values.min())))), + metadata=frozendict(metadata), + is_root=is_root, ) @classmethod - def root_node(cls, children: Iterable[Qube]) -> Qube: - return cls.make("root", QEnum(("root",)), children) + def make_root(cls, children: Iterable[Qube], metadata={}) -> Qube: + return cls.make_node( + "root", + values=QEnum(("root",)), + children=children, + metadata=metadata, + is_root=True, + ) @classmethod def load(cls, path: str | Path) -> Qube: @@ -104,18 +136,19 @@ class Qube: else: values_group = QEnum([values]) - children = [cls.make(key, values_group, children)] + children = [cls.make_node(key, values_group, children)] - return cls.root_node(children) + return cls.make_root(children) @classmethod def from_json(cls, json: dict) -> Qube: - def from_json(json: dict) -> Qube: - return Qube.make( + def from_json(json: dict, depth=0) -> Qube: + return Qube.make_node( key=json["key"], values=values_from_json(json["values"]), metadata=frozendict(json["metadata"]) if "metadata" in json else {}, - children=(from_json(c) for c in json["children"]), + children=(from_json(c, depth + 1) for c in json["children"]), + is_root=(depth == 0), ) return from_json(json) @@ -146,13 +179,13 @@ class Qube: else: values = QEnum(values) - yield Qube.make( + yield Qube.make_node( key=key, values=values, children=from_dict(children), ) - return Qube.root_node(list(from_dict(d))) + return Qube.make_root(list(from_dict(d))) def to_dict(self) -> dict: def to_dict(q: Qube) -> tuple[str, dict]: @@ -161,6 +194,13 @@ class Qube: return to_dict(self)[1] + @classmethod + def from_protobuf(cls, msg: bytes) -> Qube: + return proto_to_qube(cls, msg) + + def to_protobuf(self) -> bytes: + return qube_to_proto(self) + @classmethod def from_tree(cls, tree_str): lines = tree_str.splitlines() @@ -214,17 +254,12 @@ class Qube: @classmethod def empty(cls) -> Qube: - return Qube.root_node([]) + return Qube.make_root([]) def __str_helper__(self, depth=None, name=None) -> str: - node = ( - dataclasses.replace( - self, - data=RootNodeData(key=name, values=self.values, metadata=self.metadata), - ) - if name is not None - else self - ) + node = self + if name is not None: + node = node.replace(key=name) out = "".join(node_tree_to_string(node=node, depth=depth)) if out[-1] == "\n": out = out[:-1] @@ -239,16 +274,19 @@ class Qube: def print(self, depth=None, name: str | None = None): print(self.__str_helper__(depth=depth, name=name)) - def html(self, depth=2, collapse=True, name: str | None = None) -> HTML: - node = ( - dataclasses.replace( - self, - data=RootNodeData(key=name, values=self.values, metadata=self.metadata), - ) - if name is not None - else self + def html( + self, + depth=2, + collapse=True, + name: str | None = None, + info: Callable[[Qube], str] | None = None, + ) -> HTML: + node = self + if name is not None: + node = node.replace(key=name) + return HTML( + node_tree_to_html(node=node, depth=depth, collapse=collapse, info=info) ) - return HTML(node_tree_to_html(node=node, depth=depth, collapse=collapse)) def _repr_html_(self) -> str: return node_tree_to_html(self, depth=2, collapse=True) @@ -257,7 +295,7 @@ class Qube: def __rtruediv__(self, other: str) -> Qube: key, values = other.split("=") values_enum = QEnum((values.split("/"))) - return Qube.root_node([Qube.make(key, values_enum, self.children)]) + return Qube.make_root([Qube.make_node(key, values_enum, self.children)]) def __or__(self, other: Qube) -> Qube: return set_operations.operation( @@ -358,16 +396,16 @@ class Qube: raise KeyError( f"Key '{key}' not found in children of '{current.key}', available keys are {[c.key for c in current.children]}" ) - return Qube.root_node(current.children) + return Qube.make_root(current.children) elif isinstance(args, tuple) and len(args) == 2: key, value = args for c in self.children: if c.key == key and value in c.values: - return Qube.root_node(c.children) - raise KeyError(f"Key {key} not found in children of {self.key}") + return Qube.make_root(c.children) + raise KeyError(f"Key '{key}' not found in children of {self.key}") else: - raise ValueError("Unknown key type") + raise ValueError(f"Unknown key type {args}") @cached_property def n_leaves(self) -> int: @@ -410,7 +448,7 @@ class Qube: for c in node.children: if c.key in _keys: grandchildren = tuple(sorted(remove_key(cc) for cc in c.children)) - grandchildren = remove_key(Qube.root_node(grandchildren)).children + grandchildren = remove_key(Qube.make_root(grandchildren)).children children.extend(grandchildren) else: children.append(remove_key(c)) @@ -424,7 +462,7 @@ class Qube: if node.key in converters: converter = converters[node.key] values = [converter(v) for v in node.values] - new_node = node.replace(values=QEnum(values), dtype=type(values[0])) + new_node = node.replace(values=QEnum(values)) return new_node return node @@ -516,7 +554,8 @@ class Qube: return node.replace( children=new_children, - metadata=dict(self.metadata) | {"is_leaf": not bool(new_children)}, + metadata=dict(self.metadata) + | ({"is_leaf": not bool(new_children)} if mode == "next_level" else {}), ) return self.replace( @@ -544,6 +583,26 @@ class Qube: axes[self.key].update(self.values) return dict(axes) + def axes_info(self, depth=0) -> dict[str, AxisInfo]: + axes = defaultdict( + lambda: AxisInfo(key="", type=str, depths=set(), values=set()) + ) + for c in self.children: + for k, info in c.axes_info(depth=depth + 1).items(): + axes[k].combine(info) + + if self.key != "root": + axes[self.key].combine( + AxisInfo( + key=self.key, + type=type(next(iter(self.values))), + depths={depth}, + values=set(self.values), + ) + ) + + return dict(axes) + @cached_property def structural_hash(self) -> int: """ @@ -570,7 +629,7 @@ class Qube: """ def union(a: Qube, b: Qube) -> Qube: - b = type(self).root_node(children=(b,)) + b = type(self).make_root(children=(b,)) out = set_operations.operation( a, b, set_operations.SetOperation.UNION, type(self) ) @@ -583,3 +642,20 @@ class Qube: ) return self.replace(children=tuple(sorted(new_children))) + + def add_metadata(self, **kwargs: dict[str, Any]): + metadata = { + k: np.array( + [ + v, + ] + ) + for k, v in kwargs.items() + } + return self.replace(metadata=metadata) + + def strip_metadata(self) -> Qube: + def strip(node): + return node.replace(metadata=frozendict({})) + + return self.transform(strip) diff --git a/src/python/qubed/__init__.py b/src/python/qubed/__init__.py index 399752b..00fc3a1 100644 --- a/src/python/qubed/__init__.py +++ b/src/python/qubed/__init__.py @@ -1,3 +1,4 @@ +from . import protobuf from .Qube import Qube -__all__ = ["Qube"] +__all__ = ["Qube", "protobuf"] diff --git a/src/python/qubed/metadata.py b/src/python/qubed/metadata.py index db460d5..05e37ee 100644 --- a/src/python/qubed/metadata.py +++ b/src/python/qubed/metadata.py @@ -18,7 +18,7 @@ def make_node( children: tuple[Qube, ...], metadata: dict[str, np.ndarray] | None = None, ): - return cls.make( + return cls.make_node( key=key, values=QEnum(values), metadata={k: np.array(v).reshape(shape) for k, v in metadata.items()} @@ -39,5 +39,5 @@ def from_nodes(cls, nodes, add_root=True): root = make_node(cls, shape=shape, children=(root,), key=key, **info) if add_root: - return cls.root_node(children=(root,)) + return cls.make_root(children=(root,)) return root diff --git a/src/python/qubed/node_types.py b/src/python/qubed/node_types.py deleted file mode 100644 index 563d813..0000000 --- a/src/python/qubed/node_types.py +++ /dev/null @@ -1,27 +0,0 @@ -from dataclasses import dataclass, field - -import numpy as np -from frozendict import frozendict - -from .value_types import ValueGroup - - -@dataclass(frozen=False, eq=True, order=True, unsafe_hash=True) -class NodeData: - key: str - values: ValueGroup - metadata: frozendict[str, np.ndarray] = field( - default_factory=lambda: frozendict({}), compare=False - ) - dtype: type = str - - def summary(self) -> str: - return f"{self.key}={self.values.summary()}" if self.key != "root" else "root" - - -@dataclass(frozen=False, eq=True, order=True) -class RootNodeData(NodeData): - "Helper class to print a custom root name" - - def summary(self) -> str: - return self.key diff --git a/src/python/qubed/protobuf/__init__.py b/src/python/qubed/protobuf/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/python/qubed/protobuf/adapters.py b/src/python/qubed/protobuf/adapters.py new file mode 100644 index 0000000..4d89baa --- /dev/null +++ b/src/python/qubed/protobuf/adapters.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +from frozendict import frozendict + +from ..value_types import QEnum +from . import qube_pb2 + +if TYPE_CHECKING: + from ..Qube import Qube + + +def _ndarray_to_proto(arr: np.ndarray) -> qube_pb2.NdArray: + """np.ndarray → NdArray message""" + return qube_pb2.NdArray( + shape=list(arr.shape), + dtype=str(arr.dtype), + raw=arr.tobytes(order="C"), + ) + + +def _ndarray_from_proto(msg: qube_pb2.NdArray) -> np.ndarray: + """NdArray message → np.ndarray (immutable view)""" + return np.frombuffer(msg.raw, dtype=msg.dtype).reshape(tuple(msg.shape)) + + +def _py_to_valuegroup(value: list[str] | np.ndarray) -> qube_pb2.ValueGroup: + """Accept str-sequence *or* ndarray and return ValueGroup.""" + vg = qube_pb2.ValueGroup() + if isinstance(value, np.ndarray): + vg.tensor.CopyFrom(_ndarray_to_proto(value)) + else: + vg.s.items.extend(value) + return vg + + +def _valuegroup_to_py(vg: qube_pb2.ValueGroup) -> list[str] | np.ndarray: + """ValueGroup → list[str] *or* ndarray""" + arm = vg.WhichOneof("payload") + if arm == "tensor": + return _ndarray_from_proto(vg.tensor) + + return QEnum(vg.s.items) + + +def _py_to_metadatagroup(value: np.ndarray) -> qube_pb2.MetadataGroup: + """Accept str-sequence *or* ndarray and return ValueGroup.""" + vg = qube_pb2.MetadataGroup() + if not isinstance(value, np.ndarray): + value = np.array([value]) + + vg.tensor.CopyFrom(_ndarray_to_proto(value)) + return vg + + +def _metadatagroup_to_py(vg: qube_pb2.MetadataGroup) -> np.ndarray: + """ValueGroup → list[str] *or* ndarray""" + arm = vg.WhichOneof("payload") + if arm == "tensor": + return _ndarray_from_proto(vg.tensor) + + raise ValueError(f"Unknown arm {arm}") + + +def _qube_to_proto(q: Qube) -> qube_pb2.Qube: + """Frozen Qube dataclass → protobuf Qube message (new object).""" + return qube_pb2.Qube( + key=q.key, + values=_py_to_valuegroup(q.values), + metadata={k: _py_to_metadatagroup(v) for k, v in q.metadata.items()}, + children=[_qube_to_proto(c) for c in q.children], + is_root=q.is_root, + ) + + +def qube_to_proto(q: Qube) -> bytes: + return _qube_to_proto(q).SerializeToString() + + +def _proto_to_qube(cls: type, msg: qube_pb2.Qube) -> Qube: + """protobuf Qube message → frozen Qube dataclass (new object).""" + + return cls( + key=msg.key, + values=_valuegroup_to_py(msg.values), + metadata=frozendict( + {k: _metadatagroup_to_py(v) for k, v in msg.metadata.items()} + ), + children=tuple(_proto_to_qube(cls, c) for c in msg.children), + is_root=msg.is_root, + ) + + +def proto_to_qube(cls: type, wire: bytes) -> Qube: + msg = qube_pb2.Qube() + msg.ParseFromString(wire) + return _proto_to_qube(cls, msg) diff --git a/src/python/qubed/protobuf/qube_pb2.py b/src/python/qubed/protobuf/qube_pb2.py new file mode 100644 index 0000000..6a5ea5c --- /dev/null +++ b/src/python/qubed/protobuf/qube_pb2.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: qube.proto +# Protobuf Python Version: 5.29.0 +"""Generated protocol buffer code.""" + +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, 5, 29, 0, "", "qube.proto" +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\nqube.proto"4\n\x07NdArray\x12\r\n\x05shape\x18\x01 \x03(\x03\x12\r\n\x05\x64type\x18\x02 \x01(\t\x12\x0b\n\x03raw\x18\x03 \x01(\x0c"\x1c\n\x0bStringGroup\x12\r\n\x05items\x18\x01 \x03(\t"N\n\nValueGroup\x12\x19\n\x01s\x18\x01 \x01(\x0b\x32\x0c.StringGroupH\x00\x12\x1a\n\x06tensor\x18\x02 \x01(\x0b\x32\x08.NdArrayH\x00\x42\t\n\x07payload"6\n\rMetadataGroup\x12\x1a\n\x06tensor\x18\x01 \x01(\x0b\x32\x08.NdArrayH\x00\x42\t\n\x07payload"\xd1\x01\n\x04Qube\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x1b\n\x06values\x18\x02 \x01(\x0b\x32\x0b.ValueGroup\x12%\n\x08metadata\x18\x03 \x03(\x0b\x32\x13.Qube.MetadataEntry\x12\r\n\x05\x64type\x18\x04 \x01(\t\x12\x17\n\x08\x63hildren\x18\x05 \x03(\x0b\x32\x05.Qube\x12\x0f\n\x07is_root\x18\x06 \x01(\x08\x1a?\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x1d\n\x05value\x18\x02 \x01(\x0b\x32\x0e.MetadataGroup:\x02\x38\x01\x62\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "qube_pb2", _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals["_QUBE_METADATAENTRY"]._loaded_options = None + _globals["_QUBE_METADATAENTRY"]._serialized_options = b"8\001" + _globals["_NDARRAY"]._serialized_start = 14 + _globals["_NDARRAY"]._serialized_end = 66 + _globals["_STRINGGROUP"]._serialized_start = 68 + _globals["_STRINGGROUP"]._serialized_end = 96 + _globals["_VALUEGROUP"]._serialized_start = 98 + _globals["_VALUEGROUP"]._serialized_end = 176 + _globals["_METADATAGROUP"]._serialized_start = 178 + _globals["_METADATAGROUP"]._serialized_end = 232 + _globals["_QUBE"]._serialized_start = 235 + _globals["_QUBE"]._serialized_end = 444 + _globals["_QUBE_METADATAENTRY"]._serialized_start = 381 + _globals["_QUBE_METADATAENTRY"]._serialized_end = 444 +# @@protoc_insertion_point(module_scope) diff --git a/src/python/qubed/set_operations.py b/src/python/qubed/set_operations.py index 84f0b6e..ecb2ad5 100644 --- a/src/python/qubed/set_operations.py +++ b/src/python/qubed/set_operations.py @@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Iterable import numpy as np from frozendict import frozendict -from .node_types import NodeData from .value_types import QEnum, ValueGroup, WildcardGroup if TYPE_CHECKING: @@ -27,7 +26,7 @@ class SetOperation(Enum): @dataclass(eq=True, frozen=True) class ValuesMetadata: values: ValueGroup - metadata: dict[str, np.ndarray] + indices: list[int] | slice def QEnum_intersection( @@ -49,19 +48,17 @@ def QEnum_intersection( intersection_out = ValuesMetadata( values=QEnum(list(intersection.keys())), - metadata={ - k: v[..., tuple(intersection.values())] for k, v in A.metadata.items() - }, + indices=list(intersection.values()), ) just_A_out = ValuesMetadata( values=QEnum(list(just_A.keys())), - metadata={k: v[..., tuple(just_A.values())] for k, v in A.metadata.items()}, + indices=list(just_A.values()), ) just_B_out = ValuesMetadata( values=QEnum(list(just_B.keys())), - metadata={k: v[..., tuple(just_B.values())] for k, v in B.metadata.items()}, + indices=list(just_B.values()), ) return just_A_out, intersection_out, just_B_out @@ -76,61 +73,107 @@ def node_intersection( if isinstance(A.values, WildcardGroup) and isinstance(B.values, WildcardGroup): return ( - ValuesMetadata(QEnum([]), {}), - ValuesMetadata(WildcardGroup(), {}), - ValuesMetadata(QEnum([]), {}), + ValuesMetadata(QEnum([]), []), + ValuesMetadata(WildcardGroup(), slice(None)), + ValuesMetadata(QEnum([]), []), ) # If A is a wildcard matcher then the intersection is everything # just_A is still * # just_B is empty if isinstance(A.values, WildcardGroup): - return A, B, ValuesMetadata(QEnum([]), {}) + return A, B, ValuesMetadata(QEnum([]), []) # The reverse if B is a wildcard if isinstance(B.values, WildcardGroup): - return ValuesMetadata(QEnum([]), {}), A, B + return ValuesMetadata(QEnum([]), []), A, B raise NotImplementedError( f"Fused set operations on values types {type(A.values)} and {type(B.values)} not yet implemented" ) -def operation(A: Qube, B: Qube, operation_type: SetOperation, node_type) -> Qube | None: +def operation( + A: Qube, B: Qube, operation_type: SetOperation, node_type, depth=0 +) -> Qube | None: assert A.key == B.key, ( "The two Qube root nodes must have the same key to perform set operations," f"would usually be two root nodes. They have {A.key} and {B.key} respectively" ) + node_key = A.key + + assert A.is_root == B.is_root + is_root = A.is_root assert A.values == B.values, ( f"The two Qube root nodes must have the same values to perform set operations {A.values = }, {B.values = }" ) + node_values = A.values # Group the children of the two nodes by key nodes_by_key: defaultdict[str, tuple[list[Qube], list[Qube]]] = defaultdict( lambda: ([], []) ) - for node in A.children: - nodes_by_key[node.key][0].append(node) - for node in B.children: - nodes_by_key[node.key][1].append(node) - new_children: list[Qube] = [] + # Sort out metadata into what can stay at this level and what must move down + stayput_metadata: dict[str, np.ndarray] = {} + pushdown_metadata_A: dict[str, np.ndarray] = {} + pushdown_metadata_B: dict[str, np.ndarray] = {} + for key in set(A.metadata.keys()) | set(B.metadata.keys()): + if key not in A.metadata: + raise ValueError(f"B has key {key} but A does not. {A = } {B = }") + if key not in B.metadata: + raise ValueError(f"A has key {key} but B does not. {A = } {B = }") + + print(f"{key = } {A.metadata[key] = } {B.metadata[key]}") + A_val = A.metadata[key] + B_val = B.metadata[key] + if A_val == B_val: + print(f"{' ' * depth}Keeping metadata key '{key}' at this level") + stayput_metadata[key] = A.metadata[key] + else: + print(f"{' ' * depth}Pushing down metadata key '{key}' {A_val} {B_val}") + pushdown_metadata_A[key] = A_val + pushdown_metadata_B[key] = B_val + + # Add all the metadata that needs to be pushed down to the child nodes + # When pushing down the metadata we need to account for the fact it now affects more values + # So expand the metadata entries from shape (a, b, ..., c) to (a, b, ..., c, d) + # where d is the length of the node values + for node in A.children: + N = len(node.values) + print(N) + meta = { + k: np.broadcast_to(v[..., np.newaxis], v.shape + (N,)) + for k, v in pushdown_metadata_A.items() + } + node = node.replace(metadata=node.metadata | meta) + nodes_by_key[node.key][0].append(node) + + for node in B.children: + N = len(node.values) + meta = { + k: np.broadcast_to(v[..., np.newaxis], v.shape + (N,)) + for k, v in pushdown_metadata_B.items() + } + node = node.replace(metadata=node.metadata | meta) + nodes_by_key[node.key][1].append(node) + # For every node group, perform the set operation for key, (A_nodes, B_nodes) in nodes_by_key.items(): - output = list(_operation(key, A_nodes, B_nodes, operation_type, node_type)) + output = list( + _operation(key, A_nodes, B_nodes, operation_type, node_type, depth + 1) + ) + # print(f"{' '*depth}_operation {operation_type.name} {A_nodes} {B_nodes} out = [{output}]") new_children.extend(output) - # print(f"operation {operation_type}: {A}, {B} {new_children = }") - # print(f"{A.children = }") - # print(f"{B.children = }") - # print(f"{new_children = }") + # print(f"{' '*depth}operation {operation_type.name} [{A}] [{B}] new_children = [{new_children}]") # If there are now no children as a result of the operation, return nothing. if (A.children or B.children) and not new_children: if A.key == "root": - return A.replace(children=()) + return node_type.make_root(children=()) else: return None @@ -140,20 +183,34 @@ def operation(A: Qube, B: Qube, operation_type: SetOperation, node_type) -> Qube new_children = list(compress_children(new_children)) # The values and key are the same so we just replace the children - return A.replace(children=tuple(sorted(new_children))) + return node_type.make_node( + key=node_key, + values=node_values, + children=new_children, + metadata=stayput_metadata, + is_root=is_root, + ) + + +def get_indices(metadata: dict[str, np.ndarray], indices: list[int] | slice): + return {k: v[..., indices] for k, v in metadata.items()} -# The root node is special so we need a helper method that we can recurse on def _operation( - key: str, A: list[Qube], B: list[Qube], operation_type: SetOperation, node_type + key: str, + A: list[Qube], + B: list[Qube], + operation_type: SetOperation, + node_type, + depth: int, ) -> Iterable[Qube]: keep_just_A, keep_intersection, keep_just_B = operation_type.value - # Iterate over all pairs (node_A, node_B) values = {} for node in A + B: values[node] = ValuesMetadata(node.values, node.metadata) + # Iterate over all pairs (node_A, node_B) for node_a in A: for node_b in B: # Compute A - B, A & B, B - A @@ -171,17 +228,21 @@ def _operation( if intersection.values: new_node_a = node_a.replace( values=intersection.values, - metadata=intersection.metadata, + metadata=get_indices(node_a.metadata, intersection.indices), ) new_node_b = node_b.replace( values=intersection.values, - metadata=intersection.metadata, + metadata=get_indices(node_b.metadata, intersection.indices), ) - # print(f"{node_a = }") - # print(f"{node_b = }") - # print(f"{intersection.values =}") + # print(f"{' '*depth}{node_a = }") + # print(f"{' '*depth}{node_b = }") + # print(f"{' '*depth}{intersection.values =}") result = operation( - new_node_a, new_node_b, operation_type, node_type + new_node_a, + new_node_b, + operation_type, + node_type, + depth=depth + 1, ) if result is not None: yield result @@ -190,20 +251,20 @@ def _operation( if keep_just_A: for node in A: if values[node].values: - yield node_type.make( + yield node_type.make_node( key, children=node.children, values=values[node].values, - metadata=values[node].metadata, + metadata=get_indices(node.metadata, values[node].indices), ) if keep_just_B: for node in B: if values[node].values: - yield node_type.make( + yield node_type.make_node( key, children=node.children, values=values[node].values, - metadata=values[node].metadata, + metadata=get_indices(node.metadata, values[node].indices), ) @@ -230,7 +291,7 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]: key = child_list[0].key # Compress the children into a single node - assert all(isinstance(child.data.values, QEnum) for child in child_list), ( + assert all(isinstance(child.values, QEnum) for child in child_list), ( "All children must have QEnum values" ) @@ -241,19 +302,19 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]: metadata: frozendict[str, np.ndarray] = frozendict( { - k: np.concatenate(metadata_group, axis=0) + k: np.concatenate(metadata_group, axis=-1) for k, metadata_group in metadata_groups.items() } ) - node_data = NodeData( - key=key, - metadata=metadata, - values=QEnum(set(v for child in child_list for v in child.data.values)), - ) children = [cc for c in child_list for cc in c.children] compressed_children = compress_children(children) - new_child = node_type(data=node_data, children=compressed_children) + new_child = node_type.make_node( + key=key, + metadata=metadata, + values=QEnum(set(v for child in child_list for v in child.values)), + children=compressed_children, + ) else: # If the group is size one just keep it new_child = child_list.pop() diff --git a/src/python/qubed/tree_formatters.py b/src/python/qubed/tree_formatters.py index 49ead5a..cabd7dd 100644 --- a/src/python/qubed/tree_formatters.py +++ b/src/python/qubed/tree_formatters.py @@ -2,9 +2,8 @@ from __future__ import annotations import random from dataclasses import dataclass -from typing import TYPE_CHECKING, Iterable +from typing import TYPE_CHECKING, Callable, Iterable -import numpy as np if TYPE_CHECKING: from .Qube import Qube @@ -71,27 +70,38 @@ def node_tree_to_string(node: Qube, prefix: str = "", depth=None) -> Iterable[st def summarize_node_html( - node: Qube, collapse=False, max_summary_length=50, **kwargs + node: Qube, + collapse=False, + max_summary_length=50, + info: Callable[[Qube], str] | None = None, + **kwargs, ) -> tuple[str, Qube]: """ Extracts a summarized representation of the node while collapsing single-child paths. Returns the summary string and the last node in the chain that has multiple children. """ + if info is None: + + def info_func(node: Qube, /): + return ( + # f"dtype: {node.dtype}\n" + f"metadata: {dict(node.metadata)}\n" + ) + else: + info_func = info + summaries = [] while True: path = node.summary(**kwargs) summary = path - if "is_leaf" in node.metadata and node.metadata["is_leaf"]: - summary += " 🌿" if len(summary) > max_summary_length: summary = summary[:max_summary_length] + "..." - info = ( - f"dtype: {node.dtype.__name__}\n" - f"metadata: {dict((k, np.shape(v)) for k, v in node.metadata.items())}\n" - ) - summary = f'{summary}' + + info_string = info_func(node) + + summary = f'{summary}' summaries.append(summary) if not collapse: break @@ -105,9 +115,14 @@ def summarize_node_html( def _node_tree_to_html( - node: Qube, prefix: str = "", depth=1, connector="", **kwargs + node: Qube, + prefix: str = "", + depth=1, + connector="", + info: Callable[[Qube], str] | None = None, + **kwargs, ) -> Iterable[str]: - summary, node = summarize_node_html(node, **kwargs) + summary, node = summarize_node_html(node, info=info, **kwargs) if len(node.children) == 0: yield f'{connector}{summary}' @@ -124,13 +139,20 @@ def _node_tree_to_html( prefix + extension, depth=depth - 1, connector=prefix + connector, + info=info, **kwargs, ) yield "" def node_tree_to_html( - node: Qube, depth=1, include_css=True, include_js=True, css_id=None, **kwargs + node: Qube, + depth=1, + include_css=True, + include_js=True, + css_id=None, + info: Callable[[Qube], str] | None = None, + **kwargs, ) -> str: if css_id is None: css_id = f"qubed-tree-{random.randint(0, 1000000)}" @@ -215,5 +237,5 @@ def node_tree_to_html( nodes.forEach(n => n.addEventListener("click", nodeOnClick)); """.replace("CSS_ID", css_id) - nodes = "".join(_node_tree_to_html(node=node, depth=depth, **kwargs)) + nodes = "".join(_node_tree_to_html(node=node, depth=depth, info=info, **kwargs)) return f"{js if include_js else ''}{css if include_css else ''}
{nodes}" diff --git a/src/python/qubed/value_types.py b/src/python/qubed/value_types.py index ad38af6..e72f593 100644 --- a/src/python/qubed/value_types.py +++ b/src/python/qubed/value_types.py @@ -2,7 +2,7 @@ from __future__ import annotations import dataclasses from abc import ABC, abstractmethod -from dataclasses import dataclass, replace +from dataclasses import dataclass from datetime import date, datetime, timedelta from typing import ( TYPE_CHECKING, @@ -21,6 +21,11 @@ if TYPE_CHECKING: @dataclass(frozen=True) class ValueGroup(ABC): + @abstractmethod + def dtype(self) -> str: + "Provide a string rep of the datatype of these values" + pass + @abstractmethod def summary(self) -> str: "Provide a string summary of the value group." @@ -69,9 +74,13 @@ class QEnum(ValueGroup): """ values: EnumValuesType + _dtype: str = "str" def __init__(self, obj): object.__setattr__(self, "values", tuple(sorted(obj))) + object.__setattr__( + self, "dtype", type(self.values[0]) if len(self.values) > 0 else "str" + ) def __post_init__(self): assert isinstance(self.values, tuple) @@ -88,6 +97,9 @@ class QEnum(ValueGroup): def __contains__(self, value: Any) -> bool: return value in self.values + def dtype(self): + return self._dtype + @classmethod def from_strings(cls, values: Iterable[str]) -> Sequence[ValueGroup]: return [cls(tuple(values))] @@ -114,7 +126,7 @@ class WildcardGroup(ValueGroup): return "*" def __len__(self): - return None + return 1 def __iter__(self): return ["*"] @@ -122,6 +134,9 @@ class WildcardGroup(ValueGroup): def __bool__(self): return True + def dtype(self): + return "*" + @classmethod def from_strings(cls, values: Iterable[str]) -> Sequence[ValueGroup]: return [WildcardGroup()] @@ -398,7 +413,7 @@ def convert_datatypes(q: "Qube", conversions: dict[str, ValueGroup]) -> "Qube": ) for values_group in data_type.from_strings(q.values): # print(values_group) - yield replace(q, data=replace(q.data, values=values_group)) + yield q.replace(values=values_group) else: yield q diff --git a/src/qube.proto b/src/qube.proto new file mode 100644 index 0000000..bb2c685 --- /dev/null +++ b/src/qube.proto @@ -0,0 +1,32 @@ +syntax = "proto3"; + +message NdArray { + repeated int64 shape = 1; + string dtype = 2; + bytes raw = 3; +} + +message StringGroup {repeated string items = 1; } + +// Stores values i.e class=1/2/3 the 1/2/3 part +message ValueGroup { + oneof payload { + StringGroup s = 1; + NdArray tensor = 2; + } +} + +message MetadataGroup { + oneof payload { + NdArray tensor = 1; + } +} + +message Qube { + string key = 1; + ValueGroup values = 2; + map