Tests passing checkpoint

Tom 2025-06-03 14:57:27 +02:00
parent aaafa28dfb
commit 165bf5aca2
8 changed files with 426 additions and 536 deletions

View File

@ -11,7 +11,7 @@ from collections.abc import Callable
from dataclasses import dataclass, field
from functools import cached_property
from pathlib import Path
from typing import Any, Iterable, Iterator, Literal, Self, Sequence
from typing import Any, Iterable, Iterator, Literal, Mapping, Self, Sequence
import numpy as np
from frozendict import frozendict
@ -67,7 +67,7 @@ class QubeNamedRoot:
return self.key
@dataclass(frozen=True, eq=True, order=True, unsafe_hash=True)
@dataclass(frozen=False, eq=True, order=True, unsafe_hash=True)
class Qube:
key: str
values: ValueGroup
@ -77,6 +77,8 @@ class Qube:
children: tuple[Qube, ...] = ()
is_root: bool = False
is_leaf: bool = False
depth: int = field(default=0, compare=False)
shape: tuple[int, ...] = field(default=(), compare=False)
@classmethod
def make_node(
@ -84,16 +86,20 @@ class Qube:
key: str,
values: Iterable | QEnum | WildcardGroup,
children: Iterable[Qube],
metadata: dict[str, np.ndarray] = {},
metadata: Mapping[str, np.ndarray] = {},
is_root: bool = False,
is_leaf: bool | None = None,
) -> Qube:
children = tuple(sorted(children, key=lambda n: ((n.key, n.values.min()))))
if isinstance(values, ValueGroup):
values = values
else:
values = QEnum(values)
if not isinstance(values, WildcardGroup) and not is_root:
assert len(values) > 0, "Nodes must have at least one value"
children = tuple(sorted(children, key=lambda n: ((n.key, n.values.min()))))
return cls(
key,
values=values,
@ -105,6 +111,14 @@ class Qube:
@classmethod
def make_root(cls, children: Iterable[Qube], metadata={}) -> Qube:
def update_depth_shape(children, depth, shape):
for child in children:
child.depth = depth + 1
child.shape = shape + (len(child.values),)
update_depth_shape(child.children, child.depth, child.shape)
update_depth_shape(children, depth=0, shape=(1,))
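# Worked example of the bookkeeping this sets up (hypothetical tree
# root -> a=1/2 -> b=1/2/3): the "a" node gets depth=1, shape=(1, 2) and the
# "b" node gets depth=2, shape=(1, 2, 3), i.e. shape accumulates the value
# counts along the path from the root.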
return cls.make_node(
"root",
values=QEnum(("root",)),
@ -127,7 +141,7 @@ class Qube:
return Qube.from_json(json.load(f))
@classmethod
def from_datacube(cls, datacube: dict[str, str | Sequence[str]]) -> Qube:
def from_datacube(cls, datacube: Mapping[str, str | Sequence[str]]) -> Qube:
key_vals = list(datacube.items())[::-1]
children: list[Qube] = []

View File

@ -1,3 +1,25 @@
"""
# Set Operations
The core of this is the observation that for two sets A and B, if we compute (A - B), (A ∩ B) and (B - A)
then we can get the other operations by taking unions of the above three objects.
Union: All of them
Intersection: Just take A ∩ B
Difference: Take either A - B or B - A
Symmetric Difference (XOR): Take A - B and B - A
We start with a shallow implementation of this algorithm that only deals with a pair of nodes, not the whole tree:
shallow_set_operation(A: Qube, B: Qube) -> SetOpsResult
This takes two qubes and (morally) returns (A - B), (A ∩ B) and (B - A), but only for the values and metadata at the top level.
For technical reasons that will become clear, we actually return a struct with two copies of (A ∩ B). One has the metadata from A and the children of A (call it A'), and the other has them from B (call it B'). This is relevant when we extend the shallow algorithm to work with a whole tree, because we will recurse and compute the set operation for each pair of the children of A' and B'.
NB: Currently there are two kinds of values: QEnums, which store a list of values, and Wildcards, which 'match everything'. shallow_set_operation checks the type of the values and dispatches to different methods depending on the combination of types it finds.
"""
from __future__ import annotations
from collections import defaultdict
@ -17,6 +39,8 @@ if TYPE_CHECKING:
class SetOperation(Enum):
"Map from set operations to which combination of (A - B), (A ∩ B), (B - A) we need."
UNION = (1, 1, 1)
INTERSECTION = (0, 1, 0)
DIFFERENCE = (1, 0, 0)
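# A self-contained sketch (illustrative only, using plain Python sets rather
# than Qubes) of the decomposition described in the module docstring: compute
# the three disjoint pieces once, then assemble any set operation from them
# using a (keep_only_A, keep_intersection, keep_only_B) mask like the enum
# values above.
def _set_op_sketch(A: set, B: set, keep: tuple[int, int, int]) -> set:
    only_A, intersection, only_B = A - B, A & B, B - A
    keep_A, keep_int, keep_B = keep
    out: set = set()
    if keep_A:
        out |= only_A
    if keep_int:
        out |= intersection
    if keep_B:
        out |= only_B
    return out

# _set_op_sketch({1, 2}, {2, 3}, (1, 1, 1)) == {1, 2, 3}  # union
# _set_op_sketch({1, 2}, {2, 3}, (0, 1, 0)) == {2}        # intersection
# _set_op_sketch({1, 2}, {2, 3}, (1, 0, 0)) == {1}        # difference
# _set_op_sketch({1, 2}, {2, 3}, (1, 0, 1)) == {1, 3}     # symmetric difference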
@ -24,78 +48,145 @@ class SetOperation(Enum):
@dataclass(eq=True, frozen=True)
class ValuesMetadata:
class ValuesIndices:
"Helper class to hold the values and indices from a node."
values: ValueGroup
indices: list[int] | slice
indices: tuple[int, ...]
@classmethod
def from_values(cls, values: ValueGroup):
return cls(values=values, indices=tuple(range(len(values))))
@classmethod
def empty(cls):
return cls(values=QEnum([]), indices=())
def enumerate(self) -> Iterable[tuple[int, Any]]:
return zip(self.indices, self.values)
def QEnum_intersection(
A: ValuesMetadata,
B: ValuesMetadata,
) -> tuple[ValuesMetadata, ValuesMetadata, ValuesMetadata]:
intersection: dict[Any, int] = {}
just_A: dict[Any, int] = {}
just_B: dict[Any, int] = {val: i for i, val in enumerate(B.values)}
for index_a, val_A in enumerate(A.values):
if val_A in B.values:
just_B.pop(val_A)
intersection[val_A] = (
index_a # We throw away any overlapping metadata from B
)
else:
just_A[val_A] = index_a
intersection_out = ValuesMetadata(
values=QEnum(list(intersection.keys())),
indices=list(intersection.values()),
def get_indices(
metadata: frozendict[str, np.ndarray], indices: tuple[int, ...]
) -> frozendict[str, np.ndarray]:
"Given a metadata dict and some indices, return a new metadata dict with only the values indexed by the indices"
return frozendict(
{k: v[..., indices] for k, v in metadata.items() if isinstance(v, np.ndarray)}
)
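# Usage sketch for get_indices (the array values are made up): metadata arrays
# are sliced along their last axis, which is the axis that runs over a node's
# values.
#
#   meta = frozendict({"number": np.array([10, 11, 12, 13])})
#   get_indices(meta, (0, 2))  # -> frozendict({"number": np.array([10, 12])})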
just_A_out = ValuesMetadata(
values=QEnum(list(just_A.keys())),
indices=list(just_A.values()),
@dataclass(eq=True, frozen=True)
class SetOpResult:
"""
Given two sets A and B, all possible set operations can be constructed from A - B, A ∩ B, B - A
That is, what's only in A, the intersection and what's only in B
However, because we need to recurse on children, we actually return two intersection nodes:
only_A is a qube with:
The values in A but not in B
The metadata corresponding to these values
All the children A had
intersection_A is a qube with:
The values that intersected with B
The metadata from that intersection
All the children A had
And vice versa for only_B and intersection_B.
"""
only_A: ValuesIndices
intersection_A: ValuesIndices
intersection_B: ValuesIndices
only_B: ValuesIndices
def shallow_qenum_set_operation(A: ValuesIndices, B: ValuesIndices) -> SetOpResult:
"""
For two sets of values, partition them into four groups:
only_A: values and indices of values that are in A but not B
intersection_A: values and indices of values that are in both A and B
And vice versa for only_B and intersection_B.
Note that intersection_A and intersection_B contain the same values but the indices are different.
"""
# create four groups that map value -> index
only_A: dict[Any, int] = {val: i for i, val in A.enumerate()}
only_B: dict[Any, int] = {val: i for i, val in B.enumerate()}
intersection_A: dict[Any, int] = {}
intersection_B: dict[Any, int] = {}
# Go through all the values and move any that are in the intersection
# to the corresponding group, keeping the indices
for val in A.values:
if val in B.values:
intersection_A[val] = only_A.pop(val)
intersection_B[val] = only_B.pop(val)
def package(values_indices: dict[Any, int]) -> ValuesIndices:
return ValuesIndices(
values=QEnum(list(values_indices.keys())),
indices=tuple(values_indices.values()),
)
just_B_out = ValuesMetadata(
values=QEnum(list(just_B.keys())),
indices=list(just_B.values()),
return SetOpResult(
only_A=package(only_A),
only_B=package(only_B),
intersection_A=package(intersection_A),
intersection_B=package(intersection_B),
)
return just_A_out, intersection_out, just_B_out
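# Worked example for shallow_qenum_set_operation, assuming
# A = ValuesIndices.from_values(QEnum([1, 2, 3])) and
# B = ValuesIndices.from_values(QEnum([2, 3, 4])):
#   only_A         -> values {1},    A-indices (0,)
#   intersection_A -> values {2, 3}, A-indices (1, 2)
#   intersection_B -> values {2, 3}, B-indices (0, 1)
#   only_B         -> values {4},    B-indices (2,)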
def shallow_wildcard_set_operation(A: ValuesIndices, B: ValuesIndices) -> SetOpResult:
"""
WildcardGroups behave as if they contain all the values of whatever they match against.
For two wildcards we just return both.
For A == wildcard and B == enum we have to be more careful:
1. All of B is in the intersection so only_B is empty.
2. The wildcard may need to match against other things so only_A is A.
3. We return B in both the intersection_A and intersection_B slots.
def node_intersection(
A: ValuesMetadata,
B: ValuesMetadata,
) -> tuple[ValuesMetadata, ValuesMetadata, ValuesMetadata]:
if isinstance(A.values, QEnum) and isinstance(B.values, QEnum):
return QEnum_intersection(A, B)
This last bit happens because the wildcard basically adopts the values of whatever it sees.
"""
# Two wildcard groups have full overlap.
if isinstance(A.values, WildcardGroup) and isinstance(B.values, WildcardGroup):
return (
ValuesMetadata(QEnum([]), []),
ValuesMetadata(WildcardGroup(), slice(None)),
ValuesMetadata(QEnum([]), []),
)
return SetOpResult(ValuesIndices.empty(), A, B, ValuesIndices.empty())
# If A is a wildcard matcher then the intersection is everything
# just_A is still *
# just_B is empty
# If A is a wildcard matcher and B is not
# then the intersection is everything from B
if isinstance(A.values, WildcardGroup):
return A, B, ValuesMetadata(QEnum([]), [])
return SetOpResult(A, B, B, ValuesIndices.empty())
# The reverse if B is a wildcard
# If B is a wildcard matcher and A is not
# then the intersection is everything from A
if isinstance(B.values, WildcardGroup):
return ValuesMetadata(QEnum([]), []), A, B
return SetOpResult(ValuesIndices.empty(), A, A, B)
raise NotImplementedError(
f"Fused set operations on values types {type(A.values)} and {type(B.values)} not yet implemented"
f"One of {type(A.values)} and {type(B.values)} should be WildCardGroup"
)
def shallow_set_operation(
A: ValuesIndices,
B: ValuesIndices,
) -> SetOpResult:
if isinstance(A.values, QEnum) and isinstance(B.values, QEnum):
return shallow_qenum_set_operation(A, B)
# WildcardGroups behave as if they contain all possible values.
if isinstance(A.values, WildcardGroup) or isinstance(B.values, WildcardGroup):
return shallow_wildcard_set_operation(A, B)
raise NotImplementedError(
f"Set operations on values types {type(A.values)} and {type(B.values)} not yet implemented"
)
def operation(
A: Qube, B: Qube, operation_type: SetOperation, node_type, depth=0
) -> Qube | None:
# print(f"operation({A}, {B})")
assert A.key == B.key, (
"The two Qube root nodes must have the same key to perform set operations,"
f"would usually be two root nodes. They have {A.key} and {B.key} respectively"
@ -122,14 +213,16 @@ def operation(
pushdown_metadata_B: dict[str, np.ndarray] = {}
for key in set(A.metadata.keys()) | set(B.metadata.keys()):
if key not in A.metadata:
raise ValueError(f"B has key {key} but A does not. {A = } {B = }")
if key not in B.metadata:
raise ValueError(f"A has key {key} but B does not. {A = } {B = }")
pushdown_metadata_B[key] = B.metadata[key]
continue
if key not in B.metadata:
pushdown_metadata_A[key] = A.metadata[key]
continue
# print(f"{key = } {A.metadata[key] = } {B.metadata[key]}")
A_val = A.metadata[key]
B_val = B.metadata[key]
if A_val == B_val:
if np.allclose(A_val, B_val):
# print(f"{' ' * depth}Keeping metadata key '{key}' at this level")
stayput_metadata[key] = A.metadata[key]
else:
@ -143,7 +236,6 @@ def operation(
# where d is the length of the node values
for node in A.children:
N = len(node.values)
# print(N)
meta = {
k: np.broadcast_to(v[..., np.newaxis], v.shape + (N,))
for k, v in pushdown_metadata_A.items()
@ -160,10 +252,12 @@ def operation(
node = node.replace(metadata=node.metadata | meta)
nodes_by_key[node.key][1].append(node)
# print(f"{nodes_by_key = }")
# For every node group, perform the set operation
for key, (A_nodes, B_nodes) in nodes_by_key.items():
output = list(
_operation(key, A_nodes, B_nodes, operation_type, node_type, depth + 1)
_operation(A_nodes, B_nodes, operation_type, node_type, depth + 1)
)
# print(f"{' '*depth}_operation {operation_type.name} {A_nodes} {B_nodes} out = [{output}]")
new_children.extend(output)
@ -183,6 +277,11 @@ def operation(
new_children = list(compress_children(new_children))
# The values and key are the same so we just replace the children
if A.key == "root":
return node_type.make_root(
children=new_children,
metadata=stayput_metadata,
)
return node_type.make_node(
key=node_key,
values=node_values,
@ -192,85 +291,86 @@ def operation(
)
def get_indices(metadata: dict[str, np.ndarray], indices: list[int] | slice):
return {
k: v[..., indices] for k, v in metadata.items() if isinstance(v, np.ndarray)
}
def _operation(
key: str,
A: list[Qube],
B: list[Qube],
operation_type: SetOperation,
node_type,
depth: int,
) -> Iterable[Qube]:
keep_just_A, keep_intersection, keep_just_B = operation_type.value
"""
This operation assumes that we've found two nodes that match and now want to do a set operation on their children. Hence we take in two lists of child nodes all of which have the same key but different values.
We then loop over all pairs of children from each list and compute the intersection.
"""
# print(f"_operation({A}, {B})")
keep_only_A, keep_intersection, keep_only_B = operation_type.value
values = {}
for node in A + B:
values[node] = ValuesMetadata(node.values, node.metadata)
# We're going to progressively remove values from the starting nodes as we do intersections
# So we make a node -> ValuesIndices mapping here for both a and b
only_a: dict[Qube, ValuesIndices] = {
n: ValuesIndices.from_values(n.values) for n in A
}
only_b: dict[Qube, ValuesIndices] = {
n: ValuesIndices.from_values(n.values) for n in B
}
# Iterate over all pairs (node_A, node_B)
def make_new_node(source: Qube, values_indices: ValuesIndices):
return source.replace(
values=values_indices.values,
metadata=get_indices(source.metadata, values_indices.indices),
)
# Iterate over all pairs (node_A, node_B) and perform the shallow set operation
# Update our copy of the original node to remove anything that appears in an intersection
for node_a in A:
for node_b in B:
# Compute A - B, A & B, B - A
# Update the values for the two source nodes to remove the intersection
just_a, intersection, just_b = node_intersection(
values[node_a],
values[node_b],
)
set_ops_result = shallow_set_operation(only_a[node_a], only_b[node_b])
# Remove the intersection from the source nodes
values[node_a] = just_a
values[node_b] = just_b
# Save reduced values back to nodes
only_a[node_a] = set_ops_result.only_A
only_b[node_b] = set_ops_result.only_B
if keep_intersection:
if intersection.values:
new_node_a = node_a.replace(
values=intersection.values,
metadata=get_indices(node_a.metadata, intersection.indices),
)
new_node_b = node_b.replace(
values=intersection.values,
metadata=get_indices(node_b.metadata, intersection.indices),
)
# print(f"{' '*depth}{node_a = }")
# print(f"{' '*depth}{node_b = }")
# print(f"{' '*depth}{intersection.values =}")
if (
set_ops_result.intersection_A.values
and set_ops_result.intersection_B.values
):
result = operation(
new_node_a,
new_node_b,
make_new_node(node_a, set_ops_result.intersection_A),
make_new_node(node_b, set_ops_result.intersection_B),
operation_type,
node_type,
depth=depth + 1,
)
if result is not None:
# If we're doing a difference or xor we might want to throw away the intersection
# However we can only do this once we get to the leaf nodes, otherwise we'll
# throw away nodes too early!
# Consider Qube(root, a=1, b=1/2) - Qube(root, a=1, b=1)
# We can easily throw away the whole a node by accident here!
if keep_intersection or result.children:
yield result
# Now we've removed all the intersections we can yield the just_A and just_B parts if needed
if keep_just_A:
for node in A:
if values[node].values:
yield node_type.make_node(
key,
children=node.children,
values=values[node].values,
metadata=get_indices(node.metadata, values[node].indices),
)
if keep_just_B:
for node in B:
if values[node].values:
yield node_type.make_node(
key,
children=node.children,
values=values[node].values,
metadata=get_indices(node.metadata, values[node].indices),
elif (
not set_ops_result.intersection_A.values
and not set_ops_result.intersection_B.values
):
continue
else:
raise ValueError(
f"Only one of set_ops_result.intersection_A and set_ops_result.intersection_B is None, I didn't think that could happen! {set_ops_result = }"
)
if keep_only_A:
for node, vi in only_a.items():
if vi.values:
yield make_new_node(node, vi)
def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
if keep_only_B:
for node, vi in only_b.items():
if vi.values:
yield make_new_node(node, vi)
def compress_children(children: Iterable[Qube], depth=0) -> tuple[Qube, ...]:
"""
Helper method that only compresses a set of nodes, and doesn't do it recursively.
Used in Qubed.compress but also to maintain compression in the set operations above.
@ -287,49 +387,78 @@ def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
# Now go through and create new compressed nodes for any groups that need collapsing
new_children = []
for child_list in identical_children.values():
if len(child_list) > 1:
# If the group is size one just keep it
if len(child_list) == 1:
new_child = child_list.pop()
else:
example = child_list[0]
node_type = type(example)
key = child_list[0].key
value_type = type(example.values)
# Compress the children into a single node
assert all(isinstance(child.values, QEnum) for child in child_list), (
"All children must have QEnum values"
assert all(isinstance(child.values, value_type) for child in child_list), (
f"All nodes to be grouped must have the same value type, expected {value_type}"
)
metadata_groups = {
k: [child.metadata[k] for child in child_list]
for k in example.metadata.keys()
}
# We know the children of this group of nodes all have the same structure
# but we still need to merge the metadata across them
# children = example.children
children = merge_metadata(child_list, example.depth)
metadata: frozendict[str, np.ndarray] = frozendict(
{
k: np.concatenate(metadata_group, axis=-1)
for k, metadata_group in metadata_groups.items()
}
)
# Do we need to recursively compress here?
# children = compress_children(children, depth=depth+1)
children = [cc for c in child_list for cc in c.children]
compressed_children = compress_children(children)
new_child = node_type.make_node(
key=key,
metadata=metadata,
values=QEnum(set(v for child in child_list for v in child.values)),
children=compressed_children,
)
if value_type is QEnum:
values = QEnum(set(v for child in child_list for v in child.values))
elif value_type is WildcardGroup:
values = example.values
else:
# If the group is size one just keep it
new_child = child_list.pop()
raise ValueError(f"Unknown value type: {value_type}")
new_child = node_type.make_node(
key=example.key,
metadata=example.metadata,
values=values,
children=children,
)
new_children.append(new_child)
return tuple(sorted(new_children, key=lambda n: ((n.key, n.values.min()))))
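# Worked example (hypothetical qube): given two sibling nodes
#   b=1 -> c=1/2
#   b=2 -> c=1/2
# whose subtrees hash identically, compress_children merges them into a single
#   b=1/2 -> c=1/2
# node: the values are unioned and merge_metadata (below) lines up the matching
# descendants of each original node and concatenates their metadata arrays
# along the axis of the merged level.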
def union(a: Qube, b: Qube) -> Qube:
return operation(
a,
b,
SetOperation.UNION,
type(a),
def merge_metadata(qubes: list[Qube], axis) -> Iterable[Qube]:
"""
Given a list of qubes with identical structure,
match up the children of each node and merge the metadata
"""
# Group the children of each qube and merge them
# Exploit the fact that they have the same shape and ordering
example = qubes[0]
node_type = type(example)
for i in range(len(example.children)):
group = [q.children[i] for q in qubes]
group_example = group[0]
assert len(set((c.structural_hash for c in group))) == 1
# Collect metadata by key
metadata_groups = {
k: [q.metadata[k] for q in group] for k in group_example.metadata.keys()
}
# Concatenate the metadata together
metadata: frozendict[str, np.ndarray] = frozendict(
{
k: np.concatenate(metadata_group, axis=axis)
for k, metadata_group in metadata_groups.items()
}
)
group_children = merge_metadata(group, axis)
yield node_type.make_node(
key=group_example.key,
metadata=metadata,
values=group_example.values,
children=group_children,
)
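# Usage sketch tying the pieces together (datacube contents are illustrative;
# the printed form is approximate):
#   a = Qube.from_datacube({"class": "od", "param": ["1", "2"]})
#   b = Qube.from_datacube({"class": "rd", "param": ["1", "2"]})
#   a | b  # expected to compress to: root, class=od/rd, param=1/2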

View File

@ -1,334 +0,0 @@
#![allow(dead_code)]
use std::rc::Rc;
use smallstr::SmallString;
use slotmap::{new_key_type, SlotMap};
new_key_type! {
struct NodeId;
}
type CompactString = SmallString<[u8; 16]>;
#[derive(Clone)]
enum NodeValueTypes {
String(CompactString),
Int(i32),
}
impl From<&str> for NodeValueTypes {
fn from(s: &str) -> Self {
NodeValueTypes::String(CompactString::from(s))
}
}
impl From<i32> for NodeValueTypes {
fn from(i: i32) -> Self {
NodeValueTypes::Int(i)
}
}
enum NodeValue {
Single(NodeValueTypes),
Multiple(Vec<NodeValueTypes>),
}
struct Node<Payload> {
key: Rc<String>,
value: NodeValue,
parent: Option<NodeId>,
prev_sibling: Option<NodeId>,
next_sibling: Option<NodeId>,
// vector may be faster for traversal, but linkedlist should be faster for insertion
children: Option<(NodeId, NodeId)>, // (first_child, last_child)
data: Option<Payload>,
}
struct QueryTree<Payload> {
nodes: SlotMap<NodeId, Node<Payload>>,
}
impl<Payload> QueryTree<Payload> {
fn new() -> Self {
QueryTree {
nodes: SlotMap::with_key(),
}
}
// Adds a node with a key and single value
fn add_node<S>(&mut self, key: &Rc<String>, value: S, parent: Option<NodeId>) -> NodeId
where
S: Into<NodeValueTypes>,
{
let node_id = self.nodes.insert_with_key(|_| Node {
key: Rc::clone(key),
value: NodeValue::Single(value.into()),
parent,
prev_sibling: None,
next_sibling: None,
children: None,
data: None,
});
if let Some(parent_id) = parent {
// Determine if parent has existing children
if let Some((first_child_id, last_child_id)) = self.nodes[parent_id].children {
// Update the last child's `next_sibling`
{
let last_child = &mut self.nodes[last_child_id];
last_child.next_sibling = Some(node_id);
}
// Update the new node's `prev_sibling`
{
let new_node = &mut self.nodes[node_id];
new_node.prev_sibling = Some(last_child_id);
}
// Update parent's last child
let parent_node = &mut self.nodes[parent_id];
parent_node.children = Some((first_child_id, node_id));
} else {
// No existing children
let parent_node = &mut self.nodes[parent_id];
parent_node.children = Some((node_id, node_id));
}
}
node_id
}
// Add a single value to a node
fn add_value<S>(&mut self, node_id: NodeId, value: S)
where
S: Into<NodeValueTypes>,
{
if let Some(node) = self.nodes.get_mut(node_id) {
match &mut node.value {
NodeValue::Single(v) => {
let values = vec![v.clone(), value.into()];
node.value = NodeValue::Multiple(values);
}
NodeValue::Multiple(values) => {
values.push(value.into());
}
}
}
}
// Add multiple values to a node
fn add_values<S>(&mut self, node_id: NodeId, values: Vec<S>)
where
S: Into<NodeValueTypes>,
{
if let Some(node) = self.nodes.get_mut(node_id) {
match &mut node.value {
NodeValue::Single(v) => {
let mut new_values = vec![v.clone()];
new_values.extend(values.into_iter().map(|v| v.into()));
node.value = NodeValue::Multiple(new_values);
}
NodeValue::Multiple(existing_values) => {
existing_values.extend(values.into_iter().map(|v| v.into()));
}
}
}
}
fn get_node(&self, node_id: NodeId) -> Option<&Node<Payload>> {
self.nodes.get(node_id)
}
// TODO: better if this returns an iterator?
fn get_children(&self, node_id: NodeId) -> Vec<NodeId> {
let mut children = Vec::new();
if let Some(node) = self.get_node(node_id) {
if let Some((first_child_id, _)) = node.children {
let mut current_id = Some(first_child_id);
while let Some(cid) = current_id {
children.push(cid);
current_id = self.nodes[cid].next_sibling;
}
}
}
children
}
fn remove_node(&mut self, node_id: NodeId) {
// Remove the node and update parent and siblings
if let Some(node) = self.nodes.remove(node_id) {
// Update parent's children
if let Some(parent_id) = node.parent {
let parent_node = self.nodes.get_mut(parent_id).unwrap();
if let Some((first_child_id, last_child_id)) = parent_node.children {
if first_child_id == node_id && last_child_id == node_id {
// Node was the only child
parent_node.children = None;
} else if first_child_id == node_id {
// Node was the first child
parent_node.children = Some((node.next_sibling.unwrap(), last_child_id));
} else if last_child_id == node_id {
// Node was the last child
parent_node.children = Some((first_child_id, node.prev_sibling.unwrap()));
}
}
}
// Update siblings
if let Some(prev_id) = node.prev_sibling {
self.nodes[prev_id].next_sibling = node.next_sibling;
}
if let Some(next_id) = node.next_sibling {
self.nodes[next_id].prev_sibling = node.prev_sibling;
}
// Recursively remove children
let children_ids = self.get_children(node_id);
for child_id in children_ids {
self.remove_node(child_id);
}
}
}
fn is_root(&self, node_id: NodeId) -> bool {
self.nodes[node_id].parent.is_none()
}
fn is_leaf(&self, node_id: NodeId) -> bool {
self.nodes[node_id].children.is_none()
}
fn add_payload(&mut self, node_id: NodeId, payload: Payload) {
if let Some(node) = self.nodes.get_mut(node_id) {
node.data = Some(payload);
}
}
fn print_tree(&self) {
// Find all root nodes (nodes without a parent)
let roots: Vec<NodeId> = self
.nodes
.iter()
.filter_map(|(id, node)| {
if node.parent.is_none() {
Some(id)
} else {
None
}
})
.collect();
// Iterate through each root node and print its subtree
for (i, root_id) in roots.iter().enumerate() {
let is_last = i == roots.len() - 1;
self.print_node(*root_id, String::new(), is_last);
}
}
/// Recursively prints a node and its children.
///
/// - `node_id`: The current node's ID.
/// - `prefix`: The string prefix for indentation and branch lines.
/// - `is_last`: Boolean indicating if the node is the last child of its parent.
fn print_node(&self, node_id: NodeId, prefix: String, is_last: bool) {
// Retrieve the current node
let node = match self.nodes.get(node_id) {
Some(n) => n,
None => return, // Node not found; skip
};
// Determine the branch character
let branch = if prefix.is_empty() {
"" // Root node doesn't have a branch
} else if is_last {
"└── " // Last child
} else {
"├── " // Middle child
};
// Print the current node's key and values
print!("{}{}{}", prefix, branch, node.key);
match &node.value {
NodeValue::Single(v) => match v {
NodeValueTypes::String(s) => println!(": ({})", s),
NodeValueTypes::Int(i) => println!(": ({})", i),
},
NodeValue::Multiple(vs) => {
let values: Vec<String> = vs
.iter()
.map(|v| match v {
NodeValueTypes::String(s) => s.to_string(),
NodeValueTypes::Int(i) => i.to_string(),
})
.collect();
println!(": ({})", values.join(", "));
}
}
// Prepare the prefix for child nodes
let new_prefix = if prefix.is_empty() {
if is_last {
"    ".to_string()
} else {
"│   ".to_string()
}
} else {
if is_last {
format!("{}    ", prefix)
} else {
format!("{}│   ", prefix)
}
};
// Retrieve and iterate through child nodes
if let Some((_first_child_id, _last_child_id)) = node.children {
let children = self.get_children(node_id);
let total = children.len();
for (i, child_id) in children.iter().enumerate() {
let child_is_last = i == total - 1;
self.print_node(*child_id, new_prefix.clone(), child_is_last);
}
}
}
}
fn main() {
let mut tree: QueryTree<i16> = QueryTree::new();
let value = "hello";
let axis = Rc::new("foo".to_string());
let root_id = tree.add_node(&axis, value, None);
use std::time::Instant;
let now = Instant::now();
for _ in 0..100 {
// let child_value = format!("child_val{}", i);
let child_id = tree.add_node(&axis, value, Some(root_id));
// tree.add_value(child_id, value);
for _ in 0..100 {
// let gchild_value = format!("gchild_val{}", j);
let gchild_id = tree.add_node(&axis, value, Some(child_id));
// tree.add_values(gchild_id, vec![1, 2]);
for _ in 0..1000 {
// let ggchild_value = format!("ggchild_val{}", k);
let _ggchild_id = tree.add_node(&axis, value, Some(gchild_id));
// tree.add_value(_ggchild_id, value);
// tree.add_values(_ggchild_id, vec![1, 2, 3, 4]);
}
}
}
assert_eq!(tree.nodes.len(), 10_010_101);
let elapsed = now.elapsed();
println!("Elapsed: {:.2?}", elapsed);
// tree.print_tree();
}

View File

@ -147,9 +147,9 @@ impl Qube {
StringId(self.strings.get_or_intern(val))
}
pub(crate) fn add_node(&mut self, parent: NodeId, key: &str, values: &[&str]) -> NodeId {
pub(crate) fn add_node(&mut self, parent: NodeId, key: &str, values: impl IntoIterator<Item = impl AsRef<str>>) -> NodeId {
let key_id = self.get_or_intern(key);
let values = values.iter().map(|val| self.get_or_intern(val)).collect();
let values = values.into_iter().map(|val| self.get_or_intern(val.as_ref())).collect();
// Create the node object
let node = Node {

View File

@ -78,6 +78,8 @@ pub enum OneOrMany<T> {
Many(Vec<T>),
}
// Todo: Is there a way to rewrite this so that it doesn't allocate?
// Perhaps by returning an iterator?
impl<T> Into<Vec<T>> for OneOrMany<T> {
fn into(self) -> Vec<T> {
match self {
@ -108,10 +110,8 @@ impl Qube {
// massage values from T | Vec<T> into Vec<T>
let values: Vec<String> = values.into();
let values_refs: Vec<&str> = values.iter().map(String::as_str).collect();
let mut q = slf.borrow_mut();
let node_id = q.add_node(parent.id, key, &values_refs);
let node_id = q.add_node(parent.id, key, &values);
Ok(PyNodeRef { id: node_id, qube: slf.into()})
}

View File

@ -1,8 +1,33 @@
from frozendict import frozendict
# from frozendict import frozendict
# from qubed import Qube
def make_set(entries):
return set((frozendict(a), frozendict(b)) for a, b in entries)
# def make_set(entries):
# return set((frozendict(a), frozendict(b)) for a, b in entries)
# def construction():
# q = Qube.from_nodes(
# {
# "class": dict(values=["od", "rd"]),
# "expver": dict(values=[1, 2]),
# "stream": dict(
# values=["a", "b", "c"], metadata=dict(number=list(range(12)))
# ),
# }
# )
# assert make_set(q.leaves_with_metadata()) == make_set([
# ({'class': 'od', 'expver': 1, 'stream': 'a'}, {'number': 0}),
# ({'class': 'od', 'expver': 1, 'stream': 'b'}, {'number': 1}),
# ({'class': 'od', 'expver': 1, 'stream': 'c'}, {'number': 2}),
# ({'class': 'od', 'expver': 2, 'stream': 'a'}, {'number': 3}),
# ({'class': 'od', 'expver': 2, 'stream': 'b'}, {'number': 4}),
# ({'class': 'od', 'expver': 2, 'stream': 'c'}, {'number': 5}),
# ({'class': 'rd', 'expver': 1, 'stream': 'a'}, {'number': 6}),
# ({'class': 'rd', 'expver': 1, 'stream': 'b'}, {'number': 7}),
# ({'class': 'rd', 'expver': 1, 'stream': 'c'}, {'number': 8}),
# ({'class': 'rd', 'expver': 2, 'stream': 'a'}, {'number': 9}),
# ({'class': 'rd', 'expver': 2, 'stream': 'b'}, {'number': 10}),
# ({'class': 'rd', 'expver': 2, 'stream': 'c'}, {'number': 11})])
# def test_simple_union():
@ -42,3 +67,32 @@ def make_set(entries):
# assert make_set(expected_union.leaves_with_metadata()) == make_set(
# union.leaves_with_metadata()
# )
# def test_construction_from_fdb():
# import json
# paths = {}
# current_path = None
# i = 0
# qube = Qube.empty()
# with open("tests/data/climate_dt_paths.json") as f:
# for l in f.readlines():
# i += 1
# j = json.loads(l)
# if "type" in j and j["type"] == "path":
# paths[j["i"]] = j["path"]
# else:
# request = j.pop("keys")
# metadata = j
# # print(request, metadata)
# q = Qube.from_nodes({
# key : dict(values = [value])
# for key, value in request.items()
# }).add_metadata(**metadata)
# qube = qube | q
# if i > 100: break

View File

@ -1,12 +1,18 @@
from qubed import Qube
def set_operation_testcase(testcase):
def set_operation_testcase(name, testcase):
q1 = Qube.from_tree(testcase["q1"])
q2 = Qube.from_tree(testcase["q2"])
assert q1 | q2 == Qube.from_tree(testcase["union"])
assert q1 & q2 == Qube.from_tree(testcase["intersection"])
assert q1 - q2 == Qube.from_tree(testcase["q1 - q2"])
assert q1 | q2 == Qube.from_tree(testcase["union"]), (
f"Case: {name} Op: Union\n {q1 = }\n {q2 = }\n {q1 | q2 = }\n expected = {testcase['union']}\n"
)
assert q1 & q2 == Qube.from_tree(testcase["intersection"]), (
f"Case: {name} Op: Intersection\n {q1 = }\n {q2 = }\n {q1 - q2 = }\n expected = {testcase['intersection']}\n"
)
assert q1 - q2 == Qube.from_tree(testcase["difference"]), (
f"Case: {name} Op: Difference\n {q1 = }\n {q2 = }\n {q1 - q2 = }\n expected = {testcase['difference']}\n"
)
# These are a bunch of testcases where q1 and q2 are specified and then their union/intersection/difference are checked
@ -19,30 +25,27 @@ def set_operation_testcase(testcase):
# "q2": str(q2),
# "union": str(q1 | q2),
# "intersection": str(q1 & q2),
# "q1 - q2": str(q1 - q2),
# "difference": str(q1 - q2),
# }
# BUT MANUALLY CHECK THE OUTPUT BEFORE ADDING IT AS A TEST CASE!
testcases = [
# Simplest case, only leaves differ
{
testcases = {
"Simplest case, only leaves differ": {
"q1": "root, a=1, b=1, c=1",
"q2": "root, a=1, b=1, c=2",
"union": "root, a=1, b=1, c=1/2",
"intersection": "root",
"q1 - q2": "root",
"difference": "root, a=1, b=1, c=1",
},
# Some overlap but also each tree has unique items
{
"Some overlap but also each tree has unique items": {
"q1": "root, a=1, b=1, c=1/2/3",
"q2": "root, a=1, b=1, c=2/3/4",
"union": "root, a=1, b=1, c=1/2/3/4",
"intersection": "root, a=1, b=1, c=2/3",
"q1 - q2": "root",
"difference": "root, a=1, b=1, c=1",
},
# Overlap at two levels
{
"Overlap at two levels": {
"q1": "root, a=1, b=1/2, c=1/2/3",
"q2": "root, a=1, b=2/3, c=2/3/4",
"union": """
@ -52,26 +55,48 @@ testcases = [
b=3, c=2/3/4
""",
"intersection": "root, a=1, b=2, c=2/3",
"q1 - q2": "root",
"difference": """
root, a=1
b=1, c=1/2/3
b=2, c=1""",
},
# Check that we can merge even if the divergence point is higher
{
"Simple difference": {
"q1": "root, a=1, b=1, c=1/2/3",
"q2": "root, a=1, b=1, c=2",
"union": "root, a=1, b=1, c=1/2/3",
"intersection": "root, a=1, b=1, c=2",
"difference": "root, a=1, b=1, c=1/3",
},
"Check that we can merge even if the divergence point is higher": {
"q1": "root, a=1, b=1, c=1",
"q2": "root, a=2, b=1, c=1",
"union": "root, a=1/2, b=1, c=1",
"intersection": "root",
"q1 - q2": "root, a=1, b=1, c=1",
"difference": "root, a=1, b=1, c=1",
},
# Two equal qubes
{
"Two equal qubes": {
"q1": "root, a=1, b=1, c=1",
"q2": "root, a=1, b=1, c=1",
"union": "root, a=1, b=1, c=1",
"intersection": "root, a=1, b=1, c=1",
"q1 - q2": "root",
"difference": "root",
},
# With wildcards
{
"Two qubes that don't compress on their own but the union does": {
"q1": """
root
a=1/3, b=1
a=2, b=1/2
""",
"q2": "root, a=1/3, b=2",
"union": "root, a=1/2/3, b=1/2",
"intersection": "root",
"difference": """
root
a=1/3, b=1
a=2, b=1/2
""",
},
"With wildcards": {
"q1": "root, frequency=*, levtype=*, param=*, levelist=*, domain=a/b/c/d",
"q2": "root, frequency=*, levtype=*, param=*, domain=a/b/c/d",
"union": """
@ -80,14 +105,21 @@ testcases = [
levelist=*, domain=a/b/c/d
""",
"intersection": "root",
"q1 - q2": "root",
"difference": "root, frequency=*, levtype=*, param=*, levelist=*, domain=a/b/c/d",
},
]
"Merging wildcard groups": {
"q1": "root, levtype=pl, param=q, levelist=100/1000, quantile=*",
"q2": "root, levtype=pl, param=t, levelist=100/1000, quantile=*",
"union": "root, levtype=pl, param=q/t, levelist=100/1000, quantile=*",
"intersection": "root",
"difference": "root, levtype=pl, param=q, levelist=100/1000, quantile=*",
},
}
def test_cases():
for case in testcases:
set_operation_testcase(case)
for name, case in testcases.items():
set_operation_testcase(name, case)
def test_leaf_conservation():

View File

@ -1,17 +1,12 @@
from qubed import Qube
q = Qube.from_dict(
{
"class=od": {
"expver=0001": {"param=1": {}, "param=2": {}},
"expver=0002": {"param=1": {}, "param=2": {}},
},
"class=rd": {
"expver=0001": {"param=1": {}, "param=2": {}, "param=3": {}},
"expver=0002": {"param=1": {}, "param=2": {}},
},
}
)
q = Qube.from_tree("""
root
class=od, expver=0001/0002, param=1/2
class=rd
expver=0001, param=1/2/3
expver=0002, param=1/2
""")
wild_datacube = {
"class": "*",