diff --git a/src/python/qubed/Qube.py b/src/python/qubed/Qube.py
index 7b05b94..0ff9966 100644
--- a/src/python/qubed/Qube.py
+++ b/src/python/qubed/Qube.py
@@ -1,3 +1,8 @@
+# This makes Python evaluate type annotations lazily,
+# allowing you to reference types like Qube inside the definition of the Qube class
+# without having to write them as the string "Qube"
+from __future__ import annotations
+
 import dataclasses
 import functools
 import json
@@ -6,8 +11,9 @@ from collections.abc import Callable
 from dataclasses import dataclass
 from functools import cached_property
 from pathlib import Path
-from typing import Any, Iterable, Iterator, Literal, Sequence
+from typing import Any, Iterable, Iterator, Literal, Mapping, Sequence

+import numpy as np
 from frozendict import frozendict

 from . import set_operations
@@ -17,13 +23,18 @@ from .tree_formatters import (
     node_tree_to_html,
     node_tree_to_string,
 )
-from .value_types import QEnum, ValueGroup, WildcardGroup, values_from_json
+from .value_types import (
+    QEnum,
+    ValueGroup,
+    WildcardGroup,
+    values_from_json,
+)


 @dataclass(frozen=False, eq=True, order=True, unsafe_hash=True)
 class Qube:
     data: NodeData
-    children: tuple["Qube", ...]
+    children: tuple[Qube, ...]

     @property
     def key(self) -> str:
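
For readers unfamiliar with the behaviour the new comment refers to (PEP 563, postponed evaluation of annotations): once `from __future__ import annotations` is active, annotations are stored as strings and are never evaluated while the class body runs, which is what lets `children: tuple[Qube, ...]` appear inside the definition of Qube itself. A minimal, self-contained sketch of the effect — the Node class below is illustrative and not part of this patch:

from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Node:
    # Without the __future__ import this annotation would have to be written
    # as the string "tuple[Node, ...]", because Node does not exist yet while
    # the class body is being evaluated.
    children: tuple[Node, ...] = ()


# The stored annotation is a plain string, resolved only if something asks for it.
print(Node.__annotations__["children"])

The same future import is added to set_operations.py, tree_formatters.py and value_types.py later in this patch for the same reason.
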
@@ -34,10 +45,10 @@
         return self.data.values

     @property
-    def metadata(self) -> frozendict[str, Any]:
+    def metadata(self) -> Mapping[str, np.ndarray]:
         return self.data.metadata

-    def replace(self, **kwargs) -> "Qube":
+    def replace(self, **kwargs) -> Qube:
         data_keys = {
             k: v for k, v in kwargs.items() if k in ["key", "values", "metadata"]
         }
@@ -55,41 +66,41 @@
         return self.data.summary()

     @classmethod
-    def make(cls, key: str, values: ValueGroup, children, **kwargs) -> "Qube":
+    def make(cls, key: str, values: ValueGroup, children, **kwargs) -> Qube:
         return cls(
             data=NodeData(key, values, metadata=kwargs.get("metadata", frozendict())),
             children=tuple(sorted(children, key=lambda n: ((n.key, n.values.min())))),
         )

     @classmethod
-    def root_node(cls, children: Iterable["Qube"]) -> "Qube":
+    def root_node(cls, children: Iterable[Qube]) -> Qube:
         return cls.make("root", QEnum(("root",)), children)

     @classmethod
-    def load(cls, path: str | Path) -> "Qube":
+    def load(cls, path: str | Path) -> Qube:
         with open(path, "r") as f:
             return Qube.from_json(json.load(f))

     @classmethod
-    def from_datacube(cls, datacube: dict[str, str | Sequence[str]]) -> "Qube":
+    def from_datacube(cls, datacube: dict[str, str | Sequence[str]]) -> Qube:
         key_vals = list(datacube.items())[::-1]

-        children: list["Qube"] = []
+        children: list[Qube] = []
         for key, values in key_vals:
+            values_group: ValueGroup
             if values == "*":
-                values = WildcardGroup()
-            elif not isinstance(values, list):
-                values = [values]
+                values_group = WildcardGroup()
+            elif isinstance(values, list):
+                values_group = QEnum(values)
+            else:
+                values_group = QEnum([values])

-            if isinstance(values, list):
-                values = QEnum(values)
-
-            children = [cls.make(key, values, children)]
+            children = [cls.make(key, values_group, children)]

         return cls.root_node(children)

     @classmethod
-    def from_json(cls, json: dict) -> "Qube":
+    def from_json(cls, json: dict) -> Qube:
         def from_json(json: dict) -> Qube:
             return Qube.make(
                 key=json["key"],
@@ -112,7 +123,7 @@
         return to_json(self)

     @classmethod
-    def from_dict(cls, d: dict) -> "Qube":
+    def from_dict(cls, d: dict) -> Qube:
         def from_dict(d: dict) -> Iterator[Qube]:
             for k, children in d.items():
                 key, values = k.split("=")
@@ -131,8 +142,8 @@
         return Qube.root_node(list(from_dict(d)))

     def to_dict(self) -> dict:
-        def to_dict(q: "Qube") -> tuple[str, dict]:
-            key = f"{q.key}={','.join(str(v) for v in q.values.values)}"
+        def to_dict(q: Qube) -> tuple[str, dict]:
+            key = f"{q.key}={','.join(str(v) for v in q.values)}"
             return key, dict(to_dict(c) for c in q.children)

         return to_dict(self)[1]
@@ -189,10 +200,10 @@
         return cls.from_dict(root)

     @classmethod
-    def empty(cls) -> "Qube":
+    def empty(cls) -> Qube:
         return Qube.root_node([])

-    def __str__(self, depth=None, name=None) -> str:
+    def __str_helper__(self, depth=None, name=None) -> str:
         node = (
             dataclasses.replace(
                 self,
@@ -203,8 +214,11 @@
             )
         )
         return "".join(node_tree_to_string(node=node, depth=depth))

+    def __str__(self):
+        return self.__str_helper__()
+
     def print(self, depth=None, name: str | None = None):
-        print(self.__str__(depth=depth, name=name))
+        print(self.__str_helper__(depth=depth, name=name))

     def html(self, depth=2, collapse=True, name: str | None = None) -> HTML:
         node = (
@@ -221,27 +235,27 @@
         )
         return node_tree_to_html(self, depth=2, collapse=True)

     # Allow "key=value/value" / qube to prepend keys
-    def __rtruediv__(self, other: str) -> "Qube":
+    def __rtruediv__(self, other: str) -> Qube:
         key, values = other.split("=")
-        values = QEnum((values.split("/")))
-        return Qube.root_node([Qube.make(key, values, self.children)])
+        values_enum = QEnum((values.split("/")))
+        return Qube.root_node([Qube.make(key, values_enum, self.children)])

-    def __or__(self, other: "Qube") -> "Qube":
+    def __or__(self, other: Qube) -> Qube:
         return set_operations.operation(
             self, other, set_operations.SetOperation.UNION, type(self)
         )

-    def __and__(self, other: "Qube") -> "Qube":
+    def __and__(self, other: Qube) -> Qube:
         return set_operations.operation(
             self, other, set_operations.SetOperation.INTERSECTION, type(self)
         )

-    def __sub__(self, other: "Qube") -> "Qube":
+    def __sub__(self, other: Qube) -> Qube:
         return set_operations.operation(
             self, other, set_operations.SetOperation.DIFFERENCE, type(self)
         )

-    def __xor__(self, other: "Qube") -> "Qube":
+    def __xor__(self, other: Qube) -> Qube:
         return set_operations.operation(
             self, other, set_operations.SetOperation.SYMMETRIC_DIFFERENCE, type(self)
         )
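
As a usage note for the `__rtruediv__` overload just above: the string operand is split on "=" and "/", and the resulting key/values pair is grafted in as a new top-level node above the qube's existing children. A rough sketch, assuming the package exposes Qube at the top level and using invented keys and values:

from qubed import Qube

q = Qube.from_datacube({"param": ["2t", "10u"], "step": "0"})

# "key=value/value" / qube prepends a new top-level key above q's children.
q2 = "class=od/rd" / q
q2.print()
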
@@ -270,11 +284,10 @@
     def leaves_with_metadata(
         self, indices=()
-    ) -> Iterable[tuple[dict[str, str], dict[str, str]]]:
+    ) -> Iterator[tuple[dict[str, str], dict[str, str | np.ndarray]]]:
         if self.key == "root":
             for c in self.children:
-                for leaf in c.leaves_with_metadata(indices=()):
-                    yield leaf
+                yield from c.leaves_with_metadata(indices=())
             return

         for index, value in enumerate(self.values):
@@ -305,21 +318,21 @@
                 yield from to_list_of_cubes(c)

             if not node.children:
-                yield {node.key: list(node.values.values)}
+                yield {node.key: list(node.values)}

             for c in node.children:
                 for sub_cube in to_list_of_cubes(c):
-                    yield {node.key: list(node.values.values)} | sub_cube
+                    yield {node.key: list(node.values)} | sub_cube

         return to_list_of_cubes(self)

-    def __getitem__(self, args) -> "Qube":
+    def __getitem__(self, args) -> Qube:
         if isinstance(args, str):
             specifiers = args.split(",")
             current = self
             for specifier in specifiers:
-                key, values = specifier.split("=")
-                values = values.split("/")
+                key, values_str = specifier.split("=")
+                values = values_str.split("/")
                 for c in current.children:
                     if c.key == key and set(values) == set(c.values):
                         current = c
@@ -354,7 +367,7 @@
             return 0
         return 1 + sum(c.n_nodes for c in self.children)

-    def transform(self, func: "Callable[[Qube], Qube | Iterable[Qube]]") -> "Qube":
+    def transform(self, func: "Callable[[Qube], Qube | Iterable[Qube]]") -> Qube:
         """
         Call a function on every node of the Qube, return one or more nodes.
         If multiple nodes are returned they each get a copy of the (transformed) children of the original node.
@@ -375,8 +388,8 @@
     def remove_by_key(self, keys: str | list[str]):
         _keys: list[str] = keys if isinstance(keys, list) else [keys]

-        def remove_key(node: "Qube") -> "Qube":
-            children = []
+        def remove_key(node: Qube) -> Qube:
+            children: list[Qube] = []
             for c in node.children:
                 if c.key in _keys:
                     grandchildren = tuple(sorted(remove_key(cc) for cc in c.children))
@@ -405,17 +418,23 @@
         mode: Literal["strict", "relaxed"] = "relaxed",
         prune=True,
         consume=False,
-    ) -> "Qube":
-        # make all values lists
-        selection: dict[str, list[str] | Callable[[Any], bool]] = {
-            k: v if isinstance(v, list | Callable) else [v]
-            for k, v in selection.items()
-        }
+    ) -> Qube:
+        # Find any bare str values and replace them with [str]
+        _selection: dict[str, list[str] | Callable[[Any], bool]] = {}
+        for k, v in selection.items():
+            if isinstance(v, list):
+                _selection[k] = v
+            elif callable(v):
+                _selection[k] = v
+            else:
+                _selection[k] = [v]

         def not_none(xs):
             return tuple(x for x in xs if x is not None)

-        def select(node: Qube, selection: dict[str, list[str]]) -> Qube | None:
+        def select(
+            node: Qube, selection: dict[str, list[str] | Callable[[Any], bool]]
+        ) -> Qube | None:
             # If this node has no children but there are still parts of the request
             # that have not been consumed, then prune this whole branch
             if consume and not node.children and selection:
@@ -442,13 +461,15 @@
                 raise ValueError(f"Unknown mode argument {mode}")

             # If the key IS in the selection then check if the values match
-            if node.key in selection:
+            if node.key in _selection:
                 # If the key is specified, check if any of the values match
-                selection_criteria = selection[node.key]
-                if isinstance(selection_criteria, Callable):
+                selection_criteria = _selection[node.key]
+                if callable(selection_criteria):
                     values = QEnum((c for c in node.values if selection_criteria(c)))
+                elif isinstance(selection_criteria, list):
+                    values = QEnum((c for c in selection_criteria if c in node.values))
                 else:
-                    values = QEnum((c for c in selection[node.key] if c in node.values))
+                    raise ValueError(f"Unknown selection type {selection_criteria}")

                 # Here modes don't matter because we've explicitly filtered on this key and found nothing
                 if not values:
@@ -468,11 +489,11 @@
             return node.replace(
                 children=new_children,
-                metadata=self.metadata | {"is_leaf": not bool(new_children)},
+                metadata=dict(self.metadata) | {"is_leaf": not bool(new_children)},
             )

         return self.replace(
-            children=not_none(select(c, selection) for c in self.children)
+            children=not_none(select(c, _selection) for c in self.children)
         )

     def span(self, key: str) -> list[str]:
@@ -508,7 +529,7 @@
         return hash_node(self)

-    def compress(self) -> "Qube":
+    def compress(self) -> Qube:
         """
         This method is quite computationally heavy because of trees like this:
         root, class=d1, generation=1
@@ -519,7 +540,7 @@
         """

-        def union(a: "Qube", b: "Qube") -> "Qube":
+        def union(a: Qube, b: Qube) -> Qube:
             b = type(self).root_node(children=(b,))
             out = set_operations.operation(
                 a, b, set_operations.SetOperation.UNION, type(self)
             )
@@ -528,6 +549,8 @@
         new_children = [c.compress() for c in self.children]
         if len(new_children) > 1:
-            new_children = functools.reduce(union, new_children, Qube.empty()).children
+            new_children = list(
+                functools.reduce(union, new_children, Qube.empty()).children
+            )

         return self.replace(children=tuple(sorted(new_children)))
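
To make the reworked selection normalisation concrete, here is a hedged sketch of calling `Qube.select` with the three request shapes the code above distinguishes — a bare value, a list, and a callable predicate. The keys and values are invented for illustration:

from qubed import Qube

q = Qube.from_datacube(
    {
        "class": "od",
        "param": ["2t", "10u", "msl"],
        "step": ["0", "6", "12"],
    }
)

# A bare string is wrapped into a single-element list internally.
q.select({"param": "2t"}).print()

# A list keeps only the listed values that actually exist on the node.
q.select({"param": ["2t", "msl"]}).print()

# A callable is applied to each value on the node as a predicate.
q.select({"step": lambda v: v in ("6", "12")}).print()
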
diff --git a/src/python/qubed/convert.py b/src/python/qubed/convert.py
index c7ba76e..248af30 100644
--- a/src/python/qubed/convert.py
+++ b/src/python/qubed/convert.py
@@ -8,8 +8,10 @@ def parse_key_value_pairs(text: str):
     for segment in text.split(","):
         if "=" not in segment:
             print(segment)
-        key, values = segment.split("=", 1)  # Ensure split only happens at first "="
-        values = values.split("/")
+        key, values_str = segment.split(
+            "=", 1
+        )  # Ensure split only happens at first "="
+        values = values_str.split("/")
         result[key] = values

     return result
diff --git a/src/python/qubed/node_types.py b/src/python/qubed/node_types.py
index bb819dc..7293687 100644
--- a/src/python/qubed/node_types.py
+++ b/src/python/qubed/node_types.py
@@ -1,6 +1,7 @@
 from dataclasses import dataclass, field
-from typing import Hashable
+from typing import Mapping

+import numpy as np
 from frozendict import frozendict

 from .value_types import ValueGroup
@@ -10,7 +11,7 @@ class NodeData:
     key: str
     values: ValueGroup
-    metadata: dict[str, tuple[Hashable, ...]] = field(
+    metadata: Mapping[str, np.ndarray] = field(
         default_factory=frozendict, compare=False
     )
diff --git a/src/python/qubed/py.typed b/src/python/qubed/py.typed
new file mode 100644
index 0000000..e69de29
diff --git a/src/python/qubed/set_operations.py b/src/python/qubed/set_operations.py
index 82462c7..859b128 100644
--- a/src/python/qubed/set_operations.py
+++ b/src/python/qubed/set_operations.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from collections import defaultdict
 from dataclasses import replace
 from enum import Enum
@@ -11,7 +13,7 @@ from .node_types import NodeData
 from .value_types import QEnum, ValueGroup, WildcardGroup

 if TYPE_CHECKING:
-    from .qube import Qube
+    from .Qube import Qube


 class SetOperation(Enum):
@@ -22,7 +24,7 @@
 def node_intersection(
-    A: "ValueGroup", B: "ValueGroup"
+    A: ValueGroup, B: ValueGroup
 ) -> tuple[ValueGroup, ValueGroup, ValueGroup]:
     if isinstance(A, QEnum) and isinstance(B, QEnum):
         set_A, set_B = set(A), set(B)
@@ -49,7 +51,7 @@
     )

-def operation(A: "Qube", B: "Qube", operation_type: SetOperation, node_type) -> "Qube":
+def operation(A: Qube, B: Qube, operation_type: SetOperation, node_type) -> Qube:
     assert A.key == B.key, (
         "The two Qube root nodes must have the same key to perform set operations,"
         f"would usually be two root nodes. They have {A.key} and {B.key} respectively"
     )
@@ -60,13 +62,15 @@ def operation(A: "Qube", B: "Qube", operation_type: SetOperation, node_type) ->
     )

     # Group the children of the two nodes by key
-    nodes_by_key = defaultdict(lambda: ([], []))
+    nodes_by_key: defaultdict[str, tuple[list[Qube], list[Qube]]] = defaultdict(
+        lambda: ([], [])
+    )
     for node in A.children:
         nodes_by_key[node.key][0].append(node)
     for node in B.children:
         nodes_by_key[node.key][1].append(node)

-    new_children = []
+    new_children: list[Qube] = []

     # For every node group, perform the set operation
     for key, (A_nodes, B_nodes) in nodes_by_key.items():
@@ -77,16 +81,16 @@
     # Whenever we modify children we should recompress them
     # But since `operation` is already recursive, we only need to compress this level not all levels
     # Hence we use the non-recursive _compress method
-    new_children = compress_children(new_children)
+    new_children = list(compress_children(new_children))

     # The values and key are the same so we just replace the children
-    return replace(A, children=new_children)
+    return A.replace(children=tuple(sorted(new_children)))


 # The root node is special so we need a helper method that we can recurse on
 def _operation(
-    key: str, A: list["Qube"], B: list["Qube"], operation_type: SetOperation, node_type
-) -> Iterable["Qube"]:
+    key: str, A: list[Qube], B: list[Qube], operation_type: SetOperation, node_type
+) -> Iterable[Qube]:
     keep_just_A, keep_intersection, keep_just_B = operation_type.value

     # Iterate over all pairs (node_A, node_B)
@@ -128,7 +132,7 @@
     yield node_type.make(key, values[node], node.children)


-def compress_children(children: Iterable["Qube"]) -> tuple["Qube"]:
+def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
     """
     Helper method tht only compresses a set of nodes, and doesn't do it recursively.
     Used in Qubed.compress but also to maintain compression in the set operations above.
@@ -138,16 +142,16 @@ def compress_children(children: Iterable["Qube"]) -> tuple["Qube"]:
     identical_children = defaultdict(set)
     for child in children:
         # only care about the key and children of each node, ignore values
-        key = hash((child.key, tuple((cc.structural_hash for cc in child.children))))
-        identical_children[key].add(child)
+        h = hash((child.key, tuple((cc.structural_hash for cc in child.children))))
+        identical_children[h].add(child)

     # Now go through and create new compressed nodes for any groups that need collapsing
     new_children = []
     for child_set in identical_children.values():
         if len(child_set) > 1:
-            child_set = list(child_set)
-            node_type = type(child_set[0])
-            key = child_set[0].key
+            child_list = list(child_set)
+            node_type = type(child_list[0])
+            key = child_list[0].key

             # Compress the children into a single node
             assert all(isinstance(child.data.values, QEnum) for child in child_set), (
             )

             node_data = NodeData(
-                key=key,
+                key=str(key),
                 metadata=frozendict(),  # Todo: Implement metadata compression
-                values=QEnum(
-                    (v for child in child_set for v in child.data.values.values)
-                ),
+                values=QEnum((v for child in child_set for v in child.data.values)),
             )
-            new_child = node_type(data=node_data, children=child_set[0].children)
+            new_child = node_type(data=node_data, children=child_list[0].children)
         else:
             # If the group is size one just keep it
             new_child = child_set.pop()
@@ -170,7 +172,7 @@
     return tuple(sorted(new_children, key=lambda n: ((n.key, n.values.min()))))


-def union(a: "Qube", b: "Qube") -> "Qube":
+def union(a: Qube, b: Qube) -> Qube:
     return operation(
         a,
         b,
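
The `compress_children` helper above is what keeps set-operation results tidy: siblings with the same key and structurally identical children are folded into one node whose QEnum values are merged. A hedged sketch of the observable effect through the `|` (union) operator, with invented keys and values:

from qubed import Qube

a = Qube.from_datacube({"class": "od", "param": ["2t", "10u"]})
b = Qube.from_datacube({"class": "od", "param": ["msl"]})

# The union recurses down both trees; the two param nodes under class=od have
# identical (empty) subtrees, so compress_children merges them into a single
# param node carrying all three values.
(a | b).print()
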
diff --git a/src/python/qubed/tree_formatters.py b/src/python/qubed/tree_formatters.py
index 45597a6..7c83af8 100644
--- a/src/python/qubed/tree_formatters.py
+++ b/src/python/qubed/tree_formatters.py
@@ -1,16 +1,11 @@
+from __future__ import annotations
+
 import random
 from dataclasses import dataclass
-from typing import Iterable, Protocol, Sequence, runtime_checkable
+from typing import TYPE_CHECKING, Iterable

-
-@runtime_checkable
-class TreeLike(Protocol):
-    @property
-    def children(
-        self,
-    ) -> Sequence["TreeLike"]: ...  # Supports indexing like node.children[i]
-
-    def summary(self) -> str: ...
+if TYPE_CHECKING:
+    from .Qube import Qube


 @dataclass(frozen=True)
@@ -22,8 +17,8 @@ class HTML:

 def summarize_node(
-    node: TreeLike, collapse=False, max_summary_length=50, **kwargs
-) -> tuple[str, str, TreeLike]:
+    node: Qube, collapse=False, max_summary_length=50, **kwargs
+) -> tuple[str, str, Qube]:
     """
     Extracts a summarized representation of the node while collapsing single-child paths.
     Returns the summary string and the last node in the chain that has multiple children.
@@ -50,7 +45,7 @@ def summarize_node(
     return ", ".join(summaries), ",".join(paths), node


-def node_tree_to_string(node: TreeLike, prefix: str = "", depth=None) -> Iterable[str]:
+def node_tree_to_string(node: Qube, prefix: str = "", depth=None) -> Iterable[str]:
     summary, path, node = summarize_node(node)

     if depth is not None and depth <= 0:
@@ -74,7 +69,7 @@ def node_tree_to_string(node: TreeLike, prefix: str = "", depth=None) -> Iterabl

 def _node_tree_to_html(
-    node: TreeLike, prefix: str = "", depth=1, connector="", **kwargs
+    node: Qube, prefix: str = "", depth=1, connector="", **kwargs
 ) -> Iterable[str]:
     summary, path, node = summarize_node(node, **kwargs)

@@ -99,7 +94,7 @@
 def node_tree_to_html(
-    node: TreeLike, depth=1, include_css=True, include_js=True, css_id=None, **kwargs
+    node: Qube, depth=1, include_css=True, include_js=True, css_id=None, **kwargs
 ) -> str:
     if css_id is None:
         css_id = f"qubed-tree-{random.randint(0, 1000000)}"
diff --git a/src/python/qubed/value_types.py b/src/python/qubed/value_types.py
index 5236924..eb73b67 100644
--- a/src/python/qubed/value_types.py
+++ b/src/python/qubed/value_types.py
@@ -1,8 +1,19 @@
+from __future__ import annotations
+
 import dataclasses
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, replace
 from datetime import date, datetime, timedelta
-from typing import TYPE_CHECKING, Any, FrozenSet, Iterable, Literal, TypeVar
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    FrozenSet,
+    Iterable,
+    Iterator,
+    Literal,
+    Sequence,
+    TypeVar,
+)

 if TYPE_CHECKING:
     from .Qube import Qube
@@ -30,23 +41,19 @@ class ValueGroup(ABC):
         "Return the minimum value in the group."
         pass

-
-@dataclass(frozen=True)
-class FiniteValueGroup(ValueGroup, ABC):
+    @classmethod
     @abstractmethod
-    def __len__(self) -> int:
-        "Return how many values this group contains."
+    def from_strings(cls, values: Iterable[str]) -> Sequence[ValueGroup]:
+        "Given a list of strings, return one or more ValueGroups of this type."
         pass

     @abstractmethod
-    def __iter__(self) -> Iterable[Any]:
+    def __iter__(self) -> Iterator:
         "Iterate over the values in the group."
         pass

-    @classmethod
     @abstractmethod
-    def from_strings(cls, values: Iterable[str]) -> list["ValueGroup"]:
-        "Given a list of strings, return a one or more ValueGroups of this type."
+    def __len__(self) -> int:
         pass
@@ -55,7 +62,7 @@
 EnumValuesType = FrozenSet[T]


 @dataclass(frozen=True, order=True)
-class QEnum(FiniteValueGroup):
+class QEnum(ValueGroup):
     """
     The simplest kind of key value is just a list of strings.
     summary -> string1/string2/string....
@@ -81,8 +88,9 @@ class QEnum(FiniteValueGroup):
     def __contains__(self, value: Any) -> bool:
         return value in self.values

-    def from_strings(self, values: Iterable[str]) -> list["ValueGroup"]:
-        return [type(self)(tuple(values))]
+    @classmethod
+    def from_strings(cls, values: Iterable[str]) -> Sequence[ValueGroup]:
+        return [cls(tuple(values))]

     def min(self):
         return min(self.values)
@@ -105,6 +113,16 @@ class WildcardGroup(ValueGroup):
     def min(self):
         return "*"

+    def __len__(self):
+        return 1
+
+    def __iter__(self):
+        return iter(["*"])
+
+    @classmethod
+    def from_strings(cls, values: Iterable[str]) -> Sequence[ValueGroup]:
+        return [WildcardGroup()]
+

 class DateEnum(QEnum):
     def summary(self) -> str:
@@ -125,7 +143,7 @@ class Range(ValueGroup, ABC):
     def min(self):
         return self.start

-    def __iter__(self) -> Iterable[Any]:
+    def __iter__(self) -> Iterator[Any]:
         i = self.start
         while i <= self.end:
             yield i
@@ -145,19 +163,19 @@ class DateRange(Range):
     def __len__(self) -> int:
         return (self.end - self.start) // self.step

-    def __iter__(self) -> Iterable[date]:
+    def __iter__(self) -> Iterator[date]:
         current = self.start
         while current <= self.end if self.step.days > 0 else current >= self.end:
             yield current
             current += self.step

     @classmethod
-    def from_strings(cls, values: Iterable[str]) -> "list[DateRange | QEnum]":
+    def from_strings(cls, values: Iterable[str]) -> Sequence[DateRange | DateEnum]:
         dates = sorted([datetime.strptime(v, "%Y%m%d") for v in values])
         if len(dates) < 2:
             return [DateEnum(dates)]

-        ranges = []
+        ranges: list[DateEnum | DateRange] = []
         current_group, dates = (
             [
                 dates[0],
@@ -243,7 +261,7 @@ class TimeRange(Range):
     def min(self):
         return self.start

-    def __iter__(self) -> Iterable[Any]:
+    def __iter__(self) -> Iterator[Any]:
         return super().__iter__()

     @classmethod
@@ -369,7 +387,7 @@ def values_from_json(obj) -> ValueGroup:


 def convert_datatypes(q: "Qube", conversions: dict[str, ValueGroup]) -> "Qube":
-    def _convert(q: "Qube") -> Iterable["Qube"]:
+    def _convert(q: "Qube") -> Iterator["Qube"]:
         if q.key in conversions:
             data_type = conversions[q.key]
             assert isinstance(q.values, QEnum), (
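
With `from_strings` now part of the ValueGroup interface, each concrete group type can be built uniformly from parsed request strings. A short, hedged sketch using the classmethods added or updated in this file (module path as in the patch, values made up):

from qubed.value_types import QEnum, WildcardGroup

# QEnum.from_strings wraps all of the strings into a single enumerated group.
enum_groups = QEnum.from_strings(["2t", "10u", "msl"])

# WildcardGroup.from_strings ignores its input and returns a single wildcard group.
wildcard_groups = WildcardGroup.from_strings(["*"])

print(enum_groups, wildcard_groups)

DateRange.from_strings follows the same contract but, as its new annotation states, may return a mixture of DateEnum and DateRange groups depending on how regular the parsed dates are.
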