From 165bf5aca2fdc6b4218870da5d870087ae2f9bf9 Mon Sep 17 00:00:00 2001 From: Tom Date: Tue, 3 Jun 2025 14:57:27 +0200 Subject: [PATCH] Tests passing checkpoint --- src/python/qubed/Qube.py | 24 +- src/python/qubed/set_operations.py | 429 +++++++++++++++++++---------- src/rust/compressed_tree.rs | 334 ---------------------- src/rust/lib.rs | 4 +- src/rust/python_interface.rs | 6 +- tests/test_metadata.py | 60 +++- tests/test_set_operations.py | 86 ++++-- tests/test_wildcard.py | 19 +- 8 files changed, 426 insertions(+), 536 deletions(-) delete mode 100644 src/rust/compressed_tree.rs diff --git a/src/python/qubed/Qube.py b/src/python/qubed/Qube.py index 1994840..dc291df 100644 --- a/src/python/qubed/Qube.py +++ b/src/python/qubed/Qube.py @@ -11,7 +11,7 @@ from collections.abc import Callable from dataclasses import dataclass, field from functools import cached_property from pathlib import Path -from typing import Any, Iterable, Iterator, Literal, Self, Sequence +from typing import Any, Iterable, Iterator, Literal, Mapping, Self, Sequence import numpy as np from frozendict import frozendict @@ -67,7 +67,7 @@ class QubeNamedRoot: return self.key -@dataclass(frozen=True, eq=True, order=True, unsafe_hash=True) +@dataclass(frozen=False, eq=True, order=True, unsafe_hash=True) class Qube: key: str values: ValueGroup @@ -77,6 +77,8 @@ class Qube: children: tuple[Qube, ...] = () is_root: bool = False is_leaf: bool = False + depth: int = field(default=0, compare=False) + shape: tuple[int, ...] = field(default=(), compare=False) @classmethod def make_node( @@ -84,16 +86,20 @@ class Qube: key: str, values: Iterable | QEnum | WildcardGroup, children: Iterable[Qube], - metadata: dict[str, np.ndarray] = {}, + metadata: Mapping[str, np.ndarray] = {}, is_root: bool = False, is_leaf: bool | None = None, ) -> Qube: - children = tuple(sorted(children, key=lambda n: ((n.key, n.values.min())))) if isinstance(values, ValueGroup): values = values else: values = QEnum(values) + if not isinstance(values, WildcardGroup) and not is_root: + assert len(values) > 0, "Nodes must have at least one value" + + children = tuple(sorted(children, key=lambda n: ((n.key, n.values.min())))) + return cls( key, values=values, @@ -105,6 +111,14 @@ class Qube: @classmethod def make_root(cls, children: Iterable[Qube], metadata={}) -> Qube: + def update_depth_shape(children, depth, shape): + for child in children: + child.depth = depth + 1 + child.shape = shape + (len(child.values),) + update_depth_shape(child.children, child.depth, child.shape) + + update_depth_shape(children, depth=0, shape=(1,)) + return cls.make_node( "root", values=QEnum(("root",)), @@ -127,7 +141,7 @@ class Qube: return Qube.from_json(json.load(f)) @classmethod - def from_datacube(cls, datacube: dict[str, str | Sequence[str]]) -> Qube: + def from_datacube(cls, datacube: Mapping[str, str | Sequence[str]]) -> Qube: key_vals = list(datacube.items())[::-1] children: list[Qube] = [] diff --git a/src/python/qubed/set_operations.py b/src/python/qubed/set_operations.py index bbc7b8d..e3d2d29 100644 --- a/src/python/qubed/set_operations.py +++ b/src/python/qubed/set_operations.py @@ -1,3 +1,25 @@ +""" +# Set Operations + +The core of this is the observation that for two sets A and B, if we compute (A - B), (A ∩ B) amd (B - A) +then we can get the other operations by taking unions of the above three objects. +Union: All of them +Intersection: Just take A ∩ B +Difference: Take either A - B or B - A +Symmetric Difference (XOR): Take A - B and B - A + +We start with a shallow implementation of this algorithm that only deals with a pair of nodes, not the whole tree: + +shallow_set_operation(A: Qube, B: Qube) -> SetOpsResult + +This takes two qubes and (morally) returns (A - B), (A ∩ B) amd (B - A) but only for the values and metadata at the top level. + +For technical reasons that will become clear we actually return a struct with two copies of (A ∩ B). One has the metadata from A and the children of A call it A', and the other has them from B call it B'. This is relevant when we extend the shallow algorithm to work with a whole tree because we will recurse and compute the set operation for each pair of the children of A' and B'. + +NB: Currently there are two kinds of values, QEnums, that store a list of values and Wildcards that 'match with everything'. shallow_set_operation checks the type of values and dispatches to different methods depending on the combination of types it finds. + +""" + from __future__ import annotations from collections import defaultdict @@ -17,6 +39,8 @@ if TYPE_CHECKING: class SetOperation(Enum): + "Map from set operations to which combination of (A - B), (A ∩ B), (B - A) we need." + UNION = (1, 1, 1) INTERSECTION = (0, 1, 0) DIFFERENCE = (1, 0, 0) @@ -24,78 +48,145 @@ class SetOperation(Enum): @dataclass(eq=True, frozen=True) -class ValuesMetadata: +class ValuesIndices: + "Helper class to hold the values and indices from a node." + values: ValueGroup - indices: list[int] | slice + indices: tuple[int, ...] + + @classmethod + def from_values(cls, values: ValueGroup): + return cls(values=values, indices=tuple(range(len(values)))) + + @classmethod + def empty(cls): + return cls(values=QEnum([]), indices=()) + + def enumerate(self) -> Iterable[tuple[Any, int]]: + return zip(self.indices, self.values) -def QEnum_intersection( - A: ValuesMetadata, - B: ValuesMetadata, -) -> tuple[ValuesMetadata, ValuesMetadata, ValuesMetadata]: - intersection: dict[Any, int] = {} - just_A: dict[Any, int] = {} - just_B: dict[Any, int] = {val: i for i, val in enumerate(B.values)} - - for index_a, val_A in enumerate(A.values): - if val_A in B.values: - just_B.pop(val_A) - intersection[val_A] = ( - index_a # We throw away any overlapping metadata from B - ) - else: - just_A[val_A] = index_a - - intersection_out = ValuesMetadata( - values=QEnum(list(intersection.keys())), - indices=list(intersection.values()), +def get_indices( + metadata: frozendict[str, np.ndarray], indices: tuple[int, ...] +) -> frozendict[str, np.ndarray]: + "Given a metadata dict and some indices, return a new metadata dict with only the values indexed by the indices" + return frozendict( + {k: v[..., indices] for k, v in metadata.items() if isinstance(v, np.ndarray)} ) - just_A_out = ValuesMetadata( - values=QEnum(list(just_A.keys())), - indices=list(just_A.values()), - ) - just_B_out = ValuesMetadata( - values=QEnum(list(just_B.keys())), - indices=list(just_B.values()), - ) +@dataclass(eq=True, frozen=True) +class SetOpResult: + """ + Given two sets A and B, all possible set operations can be constructed from A - B, A ∩ B, B - A + That is, what's only in A, the intersection and what's only in B + However because we need to recurse on children we actually return two intersection node: + only_A is a qube with: + The values in A but not in B + The metadata corresponding to this values + All the children A had - return just_A_out, intersection_out, just_B_out + intersection_A is a qube with: + The values that intersected with B + The metadata from that intersection + All the children A had + + And vice versa for only_B and intersection B + """ + + only_A: ValuesIndices + intersection_A: ValuesIndices + intersection_B: ValuesIndices + only_B: ValuesIndices -def node_intersection( - A: ValuesMetadata, - B: ValuesMetadata, -) -> tuple[ValuesMetadata, ValuesMetadata, ValuesMetadata]: - if isinstance(A.values, QEnum) and isinstance(B.values, QEnum): - return QEnum_intersection(A, B) +def shallow_qenum_set_operation(A: ValuesIndices, B: ValuesIndices) -> SetOpResult: + """ + For two sets of values, partition the overlap into four groups: + only_A: values and indices of values that are in A but not B + intersection_A: values and indices of values that are in both A and B + And vice versa for only_B and intersection_B. - if isinstance(A.values, WildcardGroup) and isinstance(B.values, WildcardGroup): - return ( - ValuesMetadata(QEnum([]), []), - ValuesMetadata(WildcardGroup(), slice(None)), - ValuesMetadata(QEnum([]), []), + Note that intersection_A and intersection_B contain the same values but the indices are different. + """ + + # create four groups that map value -> index + only_A: dict[Any, int] = {val: i for i, val in A.enumerate()} + only_B: dict[Any, int] = {val: i for i, val in B.enumerate()} + intersection_A: dict[Any, int] = {} + intersection_B: dict[Any, int] = {} + + # Go through all the values and move any that are in the intersection + # to the corresponding group, keeping the indices + for val in A.values: + if val in B.values: + intersection_A[val] = only_A.pop(val) + intersection_B[val] = only_B.pop(val) + + def package(values_indices: dict[Any, int]) -> ValuesIndices: + return ValuesIndices( + values=QEnum(list(values_indices.keys())), + indices=tuple(values_indices.values()), ) - # If A is a wildcard matcher then the intersection is everything - # just_A is still * - # just_B is empty - if isinstance(A.values, WildcardGroup): - return A, B, ValuesMetadata(QEnum([]), []) + return SetOpResult( + only_A=package(only_A), + only_B=package(only_B), + intersection_A=package(intersection_A), + intersection_B=package(intersection_B), + ) - # The reverse if B is a wildcard + +def shallow_wildcard_set_operation(A: ValuesIndices, B: ValuesIndices) -> SetOpResult: + """ + WildcardGroups behave as if they contain all the values of whatever they match against. + For two wildcards we just return both. + For A == wildcard and B == enum we have to be more careful: + 1. All of B is in the intersection so only_B is None too. + 2. The wildcard may need to match against other things so only_A is A + 3. We return B in the intersection_B and intersection_A slot. + + This last bit happens because the wildcard basically adopts the values of whatever it sees. + """ + # Two wildcard groups have full overlap. + if isinstance(A.values, WildcardGroup) and isinstance(B.values, WildcardGroup): + return SetOpResult(ValuesIndices.empty(), A, B, ValuesIndices.empty()) + + # If A is a wildcard matcher and B is not + # then the intersection is everything from B + if isinstance(A.values, WildcardGroup): + return SetOpResult(A, B, B, ValuesIndices.empty()) + + # If B is a wildcard matcher and A is not + # then the intersection is everything from A if isinstance(B.values, WildcardGroup): - return ValuesMetadata(QEnum([]), []), A, B + return SetOpResult(ValuesIndices.empty(), A, A, B) raise NotImplementedError( - f"Fused set operations on values types {type(A.values)} and {type(B.values)} not yet implemented" + f"One of {type(A.values)} and {type(B.values)} should be WildCardGroup" + ) + + +def shallow_set_operation( + A: ValuesIndices, + B: ValuesIndices, +) -> SetOpResult: + if isinstance(A.values, QEnum) and isinstance(B.values, QEnum): + return shallow_qenum_set_operation(A, B) + + # WildcardGroups behave as if they contain all possible values. + if isinstance(A.values, WildcardGroup) or isinstance(B.values, WildcardGroup): + return shallow_wildcard_set_operation(A, B) + + raise NotImplementedError( + f"Set operations on values types {type(A.values)} and {type(B.values)} not yet implemented" ) def operation( A: Qube, B: Qube, operation_type: SetOperation, node_type, depth=0 ) -> Qube | None: + # print(f"operation({A}, {B})") assert A.key == B.key, ( "The two Qube root nodes must have the same key to perform set operations," f"would usually be two root nodes. They have {A.key} and {B.key} respectively" @@ -122,14 +213,16 @@ def operation( pushdown_metadata_B: dict[str, np.ndarray] = {} for key in set(A.metadata.keys()) | set(B.metadata.keys()): if key not in A.metadata: - raise ValueError(f"B has key {key} but A does not. {A = } {B = }") - if key not in B.metadata: - raise ValueError(f"A has key {key} but B does not. {A = } {B = }") + pushdown_metadata_B[key] = B.metadata[key] + continue + + if key not in B.metadata: + pushdown_metadata_A[key] = A.metadata[key] + continue - # print(f"{key = } {A.metadata[key] = } {B.metadata[key]}") A_val = A.metadata[key] B_val = B.metadata[key] - if A_val == B_val: + if np.allclose(A_val, B_val): # print(f"{' ' * depth}Keeping metadata key '{key}' at this level") stayput_metadata[key] = A.metadata[key] else: @@ -143,7 +236,6 @@ def operation( # where d is the length of the node values for node in A.children: N = len(node.values) - # print(N) meta = { k: np.broadcast_to(v[..., np.newaxis], v.shape + (N,)) for k, v in pushdown_metadata_A.items() @@ -160,10 +252,12 @@ def operation( node = node.replace(metadata=node.metadata | meta) nodes_by_key[node.key][1].append(node) + # print(f"{nodes_by_key = }") + # For every node group, perform the set operation for key, (A_nodes, B_nodes) in nodes_by_key.items(): output = list( - _operation(key, A_nodes, B_nodes, operation_type, node_type, depth + 1) + _operation(A_nodes, B_nodes, operation_type, node_type, depth + 1) ) # print(f"{' '*depth}_operation {operation_type.name} {A_nodes} {B_nodes} out = [{output}]") new_children.extend(output) @@ -183,6 +277,11 @@ def operation( new_children = list(compress_children(new_children)) # The values and key are the same so we just replace the children + if A.key == "root": + return node_type.make_root( + children=new_children, + metadata=stayput_metadata, + ) return node_type.make_node( key=node_key, values=node_values, @@ -192,85 +291,86 @@ def operation( ) -def get_indices(metadata: dict[str, np.ndarray], indices: list[int] | slice): - return { - k: v[..., indices] for k, v in metadata.items() if isinstance(v, np.ndarray) - } - - def _operation( - key: str, A: list[Qube], B: list[Qube], operation_type: SetOperation, node_type, depth: int, ) -> Iterable[Qube]: - keep_just_A, keep_intersection, keep_just_B = operation_type.value + """ + This operation assumes that we've found two nodes that match and now want to do a set operation on their children. Hence we take in two lists of child nodes all of which have the same key but different values. + We then loop over all pairs of children from each list and compute the intersection. + """ + # print(f"_operation({A}, {B})") + keep_only_A, keep_intersection, keep_only_B = operation_type.value - values = {} - for node in A + B: - values[node] = ValuesMetadata(node.values, node.metadata) + # We're going to progressively remove values from the starting nodes as we do intersections + # So we make a node -> ValuesIndices mapping here for both a and b + only_a: dict[Qube, ValuesIndices] = { + n: ValuesIndices.from_values(n.values) for n in A + } + only_b: dict[Qube, ValuesIndices] = { + n: ValuesIndices.from_values(n.values) for n in B + } - # Iterate over all pairs (node_A, node_B) + def make_new_node(source: Qube, values_indices: ValuesIndices): + return source.replace( + values=values_indices.values, + metadata=get_indices(source.metadata, values_indices.indices), + ) + + # Iterate over all pairs (node_A, node_B) and perform the shallow set operation + # Update our copy of the original node to remove anything that appears in an intersection for node_a in A: for node_b in B: - # Compute A - B, A & B, B - A - # Update the values for the two source nodes to remove the intersection - just_a, intersection, just_b = node_intersection( - values[node_a], - values[node_b], - ) + set_ops_result = shallow_set_operation(only_a[node_a], only_b[node_b]) - # Remove the intersection from the source nodes - values[node_a] = just_a - values[node_b] = just_b + # Save reduced values back to nodes + only_a[node_a] = set_ops_result.only_A + only_b[node_b] = set_ops_result.only_B - if keep_intersection: - if intersection.values: - new_node_a = node_a.replace( - values=intersection.values, - metadata=get_indices(node_a.metadata, intersection.indices), - ) - new_node_b = node_b.replace( - values=intersection.values, - metadata=get_indices(node_b.metadata, intersection.indices), - ) - # print(f"{' '*depth}{node_a = }") - # print(f"{' '*depth}{node_b = }") - # print(f"{' '*depth}{intersection.values =}") - result = operation( - new_node_a, - new_node_b, - operation_type, - node_type, - depth=depth + 1, - ) - if result is not None: + if ( + set_ops_result.intersection_A.values + and set_ops_result.intersection_B.values + ): + result = operation( + make_new_node(node_a, set_ops_result.intersection_A), + make_new_node(node_b, set_ops_result.intersection_B), + operation_type, + node_type, + depth=depth + 1, + ) + if result is not None: + # If we're doing a difference or xor we might want to throw away the intersection + # However we can only do this once we get to the leaf nodes, otherwise we'll + # throw away nodes too early! + # Consider Qube(root, a=1, b=1/2) - Qube(root, a=1, b=1) + # We can easily throw away the whole a node by accident here! + if keep_intersection or result.children: yield result - - # Now we've removed all the intersections we can yield the just_A and just_B parts if needed - if keep_just_A: - for node in A: - if values[node].values: - yield node_type.make_node( - key, - children=node.children, - values=values[node].values, - metadata=get_indices(node.metadata, values[node].indices), - ) - if keep_just_B: - for node in B: - if values[node].values: - yield node_type.make_node( - key, - children=node.children, - values=values[node].values, - metadata=get_indices(node.metadata, values[node].indices), + elif ( + not set_ops_result.intersection_A.values + and not set_ops_result.intersection_B.values + ): + continue + else: + raise ValueError( + f"Only one of set_ops_result.intersection_A and set_ops_result.intersection_B is None, I didn't think that could happen! {set_ops_result = }" ) + if keep_only_A: + for node, vi in only_a.items(): + if vi.values: + yield make_new_node(node, vi) -def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]: + if keep_only_B: + for node, vi in only_b.items(): + if vi.values: + yield make_new_node(node, vi) + + +def compress_children(children: Iterable[Qube], depth=0) -> tuple[Qube, ...]: """ Helper method tht only compresses a set of nodes, and doesn't do it recursively. Used in Qubed.compress but also to maintain compression in the set operations above. @@ -287,49 +387,78 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]: # Now go through and create new compressed nodes for any groups that need collapsing new_children = [] for child_list in identical_children.values(): - if len(child_list) > 1: + # If the group is size one just keep it + if len(child_list) == 1: + new_child = child_list.pop() + + else: example = child_list[0] node_type = type(example) - key = child_list[0].key + value_type = type(example.values) - # Compress the children into a single node - assert all(isinstance(child.values, QEnum) for child in child_list), ( - "All children must have QEnum values" + assert all(isinstance(child.values, value_type) for child in child_list), ( + f"All nodes to be grouped must have the same value type, expected {value_type}" ) - metadata_groups = { - k: [child.metadata[k] for child in child_list] - for k in example.metadata.keys() - } + # We know the children of this group of nodes all have the same structure + # but we still need to merge the metadata across them + # children = example.children + children = merge_metadata(child_list, example.depth) - metadata: frozendict[str, np.ndarray] = frozendict( - { - k: np.concatenate(metadata_group, axis=-1) - for k, metadata_group in metadata_groups.items() - } - ) + # Do we need to recusively compress here? + # children = compress_children(children, depth=depth+1) + + if value_type is QEnum: + values = QEnum(set(v for child in child_list for v in child.values)) + elif value_type is WildcardGroup: + values = example.values + else: + raise ValueError(f"Unknown value type: {value_type}") - children = [cc for c in child_list for cc in c.children] - compressed_children = compress_children(children) new_child = node_type.make_node( - key=key, - metadata=metadata, - values=QEnum(set(v for child in child_list for v in child.values)), - children=compressed_children, + key=example.key, + metadata=example.metadata, + values=values, + children=children, ) - else: - # If the group is size one just keep it - new_child = child_list.pop() new_children.append(new_child) return tuple(sorted(new_children, key=lambda n: ((n.key, n.values.min())))) -def union(a: Qube, b: Qube) -> Qube: - return operation( - a, - b, - SetOperation.UNION, - type(a), - ) +def merge_metadata(qubes: list[Qube], axis) -> Iterable[Qube]: + """ + Given a list of qubes with identical structure, + match up the children of each node and merge the metadata + """ + # Group the children of each qube and merge them + # Exploit the fact that they have the same shape and ordering + example = qubes[0] + node_type = type(example) + + for i in range(len(example.children)): + group = [q.children[i] for q in qubes] + group_example = group[0] + assert len(set((c.structural_hash for c in group))) == 1 + + # Collect metadata by key + metadata_groups = { + k: [q.metadata[k] for q in group] for k in group_example.metadata.keys() + } + + # Concatenate the metadata together + metadata: frozendict[str, np.ndarray] = frozendict( + { + k: np.concatenate(metadata_group, axis=axis) + for k, metadata_group in metadata_groups.items() + } + ) + + group_children = merge_metadata(group, axis) + yield node_type.make_node( + key=group_example.key, + metadata=metadata, + values=group_example.values, + children=group_children, + ) diff --git a/src/rust/compressed_tree.rs b/src/rust/compressed_tree.rs deleted file mode 100644 index 9148d54..0000000 --- a/src/rust/compressed_tree.rs +++ /dev/null @@ -1,334 +0,0 @@ -#![allow(dead_code)] - -use std::rc::Rc; -use smallstr::SmallString; - -use slotmap::{new_key_type, SlotMap}; - -new_key_type! { - struct NodeId; -} - -type CompactString = SmallString<[u8; 16]>; - -#[derive(Clone)] -enum NodeValueTypes { - String(CompactString), - Int(i32), -} - -impl From<&str> for NodeValueTypes { - fn from(s: &str) -> Self { - NodeValueTypes::String(CompactString::from(s)) - } -} - -impl From for NodeValueTypes { - fn from(i: i32) -> Self { - NodeValueTypes::Int(i) - } -} - -enum NodeValue { - Single(NodeValueTypes), - Multiple(Vec), -} - -struct Node { - key: Rc, - value: NodeValue, - parent: Option, - prev_sibling: Option, - next_sibling: Option, - // vector may be faster for traversal, but linkedlist should be faster for insertion - children: Option<(NodeId, NodeId)>, // (first_child, last_child) - data: Option, -} - -struct QueryTree { - nodes: SlotMap>, -} - -impl QueryTree { - fn new() -> Self { - QueryTree { - nodes: SlotMap::with_key(), - } - } - - // Adds a node with a key and single value - fn add_node(&mut self, key: &Rc, value: S, parent: Option) -> NodeId - where - S: Into, - { - let node_id = self.nodes.insert_with_key(|_| Node { - key: Rc::clone(key), - value: NodeValue::Single(value.into()), - parent, - prev_sibling: None, - next_sibling: None, - children: None, - data: None, - }); - - if let Some(parent_id) = parent { - // Determine if parent has existing children - if let Some((first_child_id, last_child_id)) = self.nodes[parent_id].children { - // Update the last child's `next_sibling` - { - let last_child = &mut self.nodes[last_child_id]; - last_child.next_sibling = Some(node_id); - } - - // Update the new node's `prev_sibling` - { - let new_node = &mut self.nodes[node_id]; - new_node.prev_sibling = Some(last_child_id); - } - - // Update parent's last child - let parent_node = &mut self.nodes[parent_id]; - parent_node.children = Some((first_child_id, node_id)); - } else { - // No existing children - let parent_node = &mut self.nodes[parent_id]; - parent_node.children = Some((node_id, node_id)); - } - } - - node_id - } - - // Add a single value to a node - fn add_value(&mut self, node_id: NodeId, value: S) - where - S: Into, - { - if let Some(node) = self.nodes.get_mut(node_id) { - match &mut node.value { - NodeValue::Single(v) => { - let values = vec![v.clone(), value.into()]; - node.value = NodeValue::Multiple(values); - } - NodeValue::Multiple(values) => { - values.push(value.into()); - } - } - } - } - - // Add multiple values to a node - fn add_values(&mut self, node_id: NodeId, values: Vec) - where - S: Into, - { - if let Some(node) = self.nodes.get_mut(node_id) { - match &mut node.value { - NodeValue::Single(v) => { - let mut new_values = vec![v.clone()]; - new_values.extend(values.into_iter().map(|v| v.into())); - node.value = NodeValue::Multiple(new_values); - } - NodeValue::Multiple(existing_values) => { - existing_values.extend(values.into_iter().map(|v| v.into())); - } - } - } - } - - fn get_node(&self, node_id: NodeId) -> Option<&Node> { - self.nodes.get(node_id) - } - - // TODO: better if this returns an iterator? - fn get_children(&self, node_id: NodeId) -> Vec { - let mut children = Vec::new(); - - if let Some(node) = self.get_node(node_id) { - if let Some((first_child_id, _)) = node.children { - let mut current_id = Some(first_child_id); - while let Some(cid) = current_id { - children.push(cid); - current_id = self.nodes[cid].next_sibling; - } - } - } - - children - } - - fn remove_node(&mut self, node_id: NodeId) { - // Remove the node and update parent and siblings - if let Some(node) = self.nodes.remove(node_id) { - // Update parent's children - if let Some(parent_id) = node.parent { - let parent_node = self.nodes.get_mut(parent_id).unwrap(); - if let Some((first_child_id, last_child_id)) = parent_node.children { - if first_child_id == node_id && last_child_id == node_id { - // Node was the only child - parent_node.children = None; - } else if first_child_id == node_id { - // Node was the first child - parent_node.children = Some((node.next_sibling.unwrap(), last_child_id)); - } else if last_child_id == node_id { - // Node was the last child - parent_node.children = Some((first_child_id, node.prev_sibling.unwrap())); - } - } - } - - // Update siblings - if let Some(prev_id) = node.prev_sibling { - self.nodes[prev_id].next_sibling = node.next_sibling; - } - if let Some(next_id) = node.next_sibling { - self.nodes[next_id].prev_sibling = node.prev_sibling; - } - - // Recursively remove children - let children_ids = self.get_children(node_id); - for child_id in children_ids { - self.remove_node(child_id); - } - } - } - - fn is_root(&self, node_id: NodeId) -> bool { - self.nodes[node_id].parent.is_none() - } - - fn is_leaf(&self, node_id: NodeId) -> bool { - self.nodes[node_id].children.is_none() - } - - fn add_payload(&mut self, node_id: NodeId, payload: Payload) { - if let Some(node) = self.nodes.get_mut(node_id) { - node.data = Some(payload); - } - } - - fn print_tree(&self) { - // Find all root nodes (nodes without a parent) - let roots: Vec = self - .nodes - .iter() - .filter_map(|(id, node)| { - if node.parent.is_none() { - Some(id) - } else { - None - } - }) - .collect(); - - // Iterate through each root node and print its subtree - for (i, root_id) in roots.iter().enumerate() { - let is_last = i == roots.len() - 1; - self.print_node(*root_id, String::new(), is_last); - } - } - - /// Recursively prints a node and its children. - /// - /// - `node_id`: The current node's ID. - /// - `prefix`: The string prefix for indentation and branch lines. - /// - `is_last`: Boolean indicating if the node is the last child of its parent. - fn print_node(&self, node_id: NodeId, prefix: String, is_last: bool) { - // Retrieve the current node - let node = match self.nodes.get(node_id) { - Some(n) => n, - None => return, // Node not found; skip - }; - - // Determine the branch character - let branch = if prefix.is_empty() { - "" // Root node doesn't have a branch - } else if is_last { - "└── " // Last child - } else { - "├── " // Middle child - }; - - // Print the current node's key and values - print!("{}{}{}", prefix, branch, node.key); - match &node.value { - NodeValue::Single(v) => match v { - NodeValueTypes::String(s) => println!(": ({})", s), - NodeValueTypes::Int(i) => println!(": ({})", i), - }, - NodeValue::Multiple(vs) => { - let values: Vec = vs - .iter() - .map(|v| match v { - NodeValueTypes::String(s) => s.to_string(), - NodeValueTypes::Int(i) => i.to_string(), - }) - .collect(); - println!(": ({})", values.join(", ")); - } - } - - // Prepare the prefix for child nodes - let new_prefix = if prefix.is_empty() { - if is_last { - " ".to_string() - } else { - "│ ".to_string() - } - } else { - if is_last { - format!("{} ", prefix) - } else { - format!("{}│ ", prefix) - } - }; - - // Retrieve and iterate through child nodes - if let Some((_first_child_id, _last_child_id)) = node.children { - let children = self.get_children(node_id); - let total = children.len(); - for (i, child_id) in children.iter().enumerate() { - let child_is_last = i == total - 1; - self.print_node(*child_id, new_prefix.clone(), child_is_last); - } - } - } -} - -fn main() { - let mut tree: QueryTree = QueryTree::new(); - - let value = "hello"; - let axis = Rc::new("foo".to_string()); - - let root_id = tree.add_node(&axis, value, None); - - use std::time::Instant; - let now = Instant::now(); - - for _ in 0..100 { - // let child_value = format!("child_val{}", i); - let child_id = tree.add_node(&axis, value, Some(root_id)); - // tree.add_value(child_id, value); - - for _ in 0..100 { - // let gchild_value = format!("gchild_val{}", j); - let gchild_id = tree.add_node(&axis, value, Some(child_id)); - // tree.add_values(gchild_id, vec![1, 2]); - - for _ in 0..1000 { - // let ggchild_value = format!("ggchild_val{}", k); - let _ggchild_id = tree.add_node(&axis, value, Some(gchild_id)); - // tree.add_value(_ggchild_id, value); - // tree.add_values(_ggchild_id, vec![1, 2, 3, 4]); - } - } - } - - assert_eq!(tree.nodes.len(), 10_010_101); - - let elapsed = now.elapsed(); - println!("Elapsed: {:.2?}", elapsed); - - // tree.print_tree(); -} diff --git a/src/rust/lib.rs b/src/rust/lib.rs index 2411e30..9b824ca 100644 --- a/src/rust/lib.rs +++ b/src/rust/lib.rs @@ -147,9 +147,9 @@ impl Qube { StringId(self.strings.get_or_intern(val)) } - pub(crate) fn add_node(&mut self, parent: NodeId, key: &str, values: &[&str]) -> NodeId { + pub(crate) fn add_node(&mut self, parent: NodeId, key: &str, values: impl IntoIterator>) -> NodeId { let key_id = self.get_or_intern(key); - let values = values.iter().map(|val| self.get_or_intern(val)).collect(); + let values = values.into_iter().map(|val| self.get_or_intern(val.as_ref())).collect(); // Create the node object let node = Node { diff --git a/src/rust/python_interface.rs b/src/rust/python_interface.rs index fc6519f..bf2425b 100644 --- a/src/rust/python_interface.rs +++ b/src/rust/python_interface.rs @@ -78,6 +78,8 @@ pub enum OneOrMany { Many(Vec), } +// Todo: Is there a way to rewrite this so that is doesn't allocate? +// Perhaps by returning an iterator? impl Into> for OneOrMany { fn into(self) -> Vec { match self { @@ -108,10 +110,8 @@ impl Qube { // massage values from T | Vec into Vec let values: Vec = values.into(); - let values_refs: Vec<&str> = values.iter().map(String::as_str).collect(); - let mut q = slf.borrow_mut(); - let node_id = q.add_node(parent.id, key, &values_refs); + let node_id = q.add_node(parent.id, key, &values); Ok(PyNodeRef { id: node_id, qube: slf.into()}) } diff --git a/tests/test_metadata.py b/tests/test_metadata.py index ceaedf9..cce8b7c 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -1,8 +1,33 @@ -from frozendict import frozendict +# from frozendict import frozendict +# from qubed import Qube -def make_set(entries): - return set((frozendict(a), frozendict(b)) for a, b in entries) +# def make_set(entries): +# return set((frozendict(a), frozendict(b)) for a, b in entries) + +# def construction(): +# q = Qube.from_nodes( +# { +# "class": dict(values=["od", "rd"]), +# "expver": dict(values=[1, 2]), +# "stream": dict( +# values=["a", "b", "c"], metadata=dict(number=list(range(12))) +# ), +# } +# ) +# assert make_set(q.leaves_with_metadata()) == make_set([ +# ({'class': 'od', 'expver': 1, 'stream': 'a'}, {'number': 0}), +# ({'class': 'od', 'expver': 1, 'stream': 'b'}, {'number': 1}), +# ({'class': 'od', 'expver': 1, 'stream': 'c'}, {'number': 2}), +# ({'class': 'od', 'expver': 2, 'stream': 'a'}, {'number': 3}), +# ({'class': 'od', 'expver': 2, 'stream': 'b'}, {'number': 4}), +# ({'class': 'od', 'expver': 2, 'stream': 'c'}, {'number': 5}), +# ({'class': 'rd', 'expver': 1, 'stream': 'a'}, {'number': 6}), +# ({'class': 'rd', 'expver': 1, 'stream': 'b'}, {'number': 7}), +# ({'class': 'rd', 'expver': 1, 'stream': 'c'}, {'number': 8}), +# ({'class': 'rd', 'expver': 2, 'stream': 'a'}, {'number': 9}), +# ({'class': 'rd', 'expver': 2, 'stream': 'b'}, {'number': 10}), +# ({'class': 'rd', 'expver': 2, 'stream': 'c'}, {'number': 11})]) # def test_simple_union(): @@ -42,3 +67,32 @@ def make_set(entries): # assert make_set(expected_union.leaves_with_metadata()) == make_set( # union.leaves_with_metadata() # ) + + +# def test_construction_from_fdb(): +# import json +# paths = {} +# current_path = None +# i = 0 + +# qube = Qube.empty() +# with open("tests/data/climate_dt_paths.json") as f: +# for l in f.readlines(): +# i += 1 +# j = json.loads(l) +# if "type" in j and j["type"] == "path": +# paths[j["i"]] = j["path"] + +# else: +# request = j.pop("keys") +# metadata = j +# # print(request, metadata) + +# q = Qube.from_nodes({ +# key : dict(values = [value]) +# for key, value in request.items() +# }).add_metadata(**metadata) + +# qube = qube | q + +# if i > 100: break diff --git a/tests/test_set_operations.py b/tests/test_set_operations.py index cbd1ddc..4e48cfc 100644 --- a/tests/test_set_operations.py +++ b/tests/test_set_operations.py @@ -1,12 +1,18 @@ from qubed import Qube -def set_operation_testcase(testcase): +def set_operation_testcase(name, testcase): q1 = Qube.from_tree(testcase["q1"]) q2 = Qube.from_tree(testcase["q2"]) - assert q1 | q2 == Qube.from_tree(testcase["union"]) - assert q1 & q2 == Qube.from_tree(testcase["intersection"]) - assert q1 - q2 == Qube.from_tree(testcase["q1 - q2"]) + assert q1 | q2 == Qube.from_tree(testcase["union"]), ( + f"Case: {name} Op: Union\n {q1 = }\n {q2 = }\n {q1 | q2 = }\n expected = {testcase['union']}\n" + ) + assert q1 & q2 == Qube.from_tree(testcase["intersection"]), ( + f"Case: {name} Op: Intersection\n {q1 = }\n {q2 = }\n {q1 - q2 = }\n expected = {testcase['intersection']}\n" + ) + assert q1 - q2 == Qube.from_tree(testcase["difference"]), ( + f"Case: {name} Op: Difference\n {q1 = }\n {q2 = }\n {q1 - q2 = }\n expected = {testcase['difference']}\n" + ) # These are a bunch of testcases where q1 and q2 are specified and then their union/intersection/difference are checked @@ -19,30 +25,27 @@ def set_operation_testcase(testcase): # "q2": str(q2), # "union": str(q1 | q2), # "intersection": str(q1 & q2), -# "q1 - q2": str(q1 - q2), +# "difference": str(q1 - q2), # } # BUT MANUALLY CHECK THE OUTPUT BEFORE ADDING IT AS A TEST CASE! -testcases = [ - # Simplest case, only leaves differ - { +testcases = { + "Simplest case, only leaves differ": { "q1": "root, a=1, b=1, c=1", "q2": "root, a=1, b=1, c=2", "union": "root, a=1, b=1, c=1/2", "intersection": "root", - "q1 - q2": "root", + "difference": "root, a=1, b=1, c=1", }, - # Some overlap but also each tree has unique items - { + "Some overlap but also each tree has unique items": { "q1": "root, a=1, b=1, c=1/2/3", "q2": "root, a=1, b=1, c=2/3/4", "union": "root, a=1, b=1, c=1/2/3/4", "intersection": "root, a=1, b=1, c=2/3", - "q1 - q2": "root", + "difference": "root, a=1, b=1, c=1", }, - # Overlap at two levels - { + "Overlap at two levels": { "q1": "root, a=1, b=1/2, c=1/2/3", "q2": "root, a=1, b=2/3, c=2/3/4", "union": """ @@ -52,26 +55,48 @@ testcases = [ └── b=3, c=2/3/4 """, "intersection": "root, a=1, b=2, c=2/3", - "q1 - q2": "root", + "difference": """ + root, a=1 + ├── b=1, c=1/2/3 + └── b=2, c=1""", }, - # Check that we can merge even if the divergence point is higher - { + "Simple difference": { + "q1": "root, a=1, b=1, c=1/2/3", + "q2": "root, a=1, b=1, c=2", + "union": "root, a=1, b=1, c=1/2/3", + "intersection": "root, a=1, b=1, c=2", + "difference": "root, a=1, b=1, c=1/3", + }, + "Check that we can merge even if the divergence point is higher": { "q1": "root, a=1, b=1, c=1", "q2": "root, a=2, b=1, c=1", "union": "root, a=1/2, b=1, c=1", "intersection": "root", - "q1 - q2": "root, a=1, b=1, c=1", + "difference": "root, a=1, b=1, c=1", }, - # Two equal qubes - { + "Two equal qubes": { "q1": "root, a=1, b=1, c=1", "q2": "root, a=1, b=1, c=1", "union": "root, a=1, b=1, c=1", "intersection": "root, a=1, b=1, c=1", - "q1 - q2": "root", + "difference": "root", }, - # With wildcards - { + "Two qubes that don't compress on their own but the union does": { + "q1": """ + root + ├── a=1/3, b=1 + └── a=2, b=1/2 + """, + "q2": "root, a=1/3, b=2", + "union": "root, a=1/2/3, b=1/2", + "intersection": "root", + "difference": """ + root + ├── a=1/3, b=1 + └── a=2, b=1/2 + """, + }, + "With wildcards": { "q1": "root, frequency=*, levtype=*, param=*, levelist=*, domain=a/b/c/d", "q2": "root, frequency=*, levtype=*, param=*, domain=a/b/c/d", "union": """ @@ -80,14 +105,21 @@ testcases = [ └── levelist=*, domain=a/b/c/d """, "intersection": "root", - "q1 - q2": "root", + "difference": "root, frequency=*, levtype=*, param=*, levelist=*, domain=a/b/c/d", }, -] + "Merging wildcard groups": { + "q1": "root, levtype=pl, param=q, levelist=100/1000, quantile=*", + "q2": "root, levtype=pl, param=t, levelist=100/1000, quantile=*", + "union": "root, levtype=pl, param=q/t, levelist=100/1000, quantile=*", + "intersection": "root", + "difference": "root, levtype=pl, param=q, levelist=100/1000, quantile=*", + }, +} def test_cases(): - for case in testcases: - set_operation_testcase(case) + for name, case in testcases.items(): + set_operation_testcase(name, case) def test_leaf_conservation(): diff --git a/tests/test_wildcard.py b/tests/test_wildcard.py index a867464..ffd6087 100644 --- a/tests/test_wildcard.py +++ b/tests/test_wildcard.py @@ -1,17 +1,12 @@ from qubed import Qube -q = Qube.from_dict( - { - "class=od": { - "expver=0001": {"param=1": {}, "param=2": {}}, - "expver=0002": {"param=1": {}, "param=2": {}}, - }, - "class=rd": { - "expver=0001": {"param=1": {}, "param=2": {}, "param=3": {}}, - "expver=0002": {"param=1": {}, "param=2": {}}, - }, - } -) +q = Qube.from_tree(""" +root +├── class=od, expver=0001/0002, param=1/2 +└── class=rd + ├── expver=0001, param=1/2/3 + └── expver=0002, param=1/2 +""") wild_datacube = { "class": "*",