Fix all of mypy's complaints.

This commit is contained in:
Tom 2025-04-23 12:43:49 +01:00
parent 10106ba6d8
commit 7b36a76154
7 changed files with 158 additions and 117 deletions

View File

@ -1,3 +1,8 @@
# This causes python types to be evaluated later,
# allowing you to reference types like Qube inside the definion of the Qube class
# without having to do "Qube"
from __future__ import annotations
import dataclasses
import functools
import json
@ -6,8 +11,9 @@ from collections.abc import Callable
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import Any, Iterable, Iterator, Literal, Sequence
from typing import Any, Iterable, Iterator, Literal, Mapping, Sequence
import numpy as np
from frozendict import frozendict
from . import set_operations
@ -17,13 +23,18 @@ from .tree_formatters import (
node_tree_to_html,
node_tree_to_string,
)
from .value_types import QEnum, ValueGroup, WildcardGroup, values_from_json
from .value_types import (
QEnum,
ValueGroup,
WildcardGroup,
values_from_json,
)
@dataclass(frozen=False, eq=True, order=True, unsafe_hash=True)
class Qube:
data: NodeData
children: tuple["Qube", ...]
children: tuple[Qube, ...]
@property
def key(self) -> str:
@ -34,10 +45,10 @@ class Qube:
return self.data.values
@property
def metadata(self) -> frozendict[str, Any]:
def metadata(self) -> Mapping[str, np.ndarray]:
return self.data.metadata
def replace(self, **kwargs) -> "Qube":
def replace(self, **kwargs) -> Qube:
data_keys = {
k: v for k, v in kwargs.items() if k in ["key", "values", "metadata"]
}
@ -55,41 +66,41 @@ class Qube:
return self.data.summary()
@classmethod
def make(cls, key: str, values: ValueGroup, children, **kwargs) -> "Qube":
def make(cls, key: str, values: ValueGroup, children, **kwargs) -> Qube:
return cls(
data=NodeData(key, values, metadata=kwargs.get("metadata", frozendict())),
children=tuple(sorted(children, key=lambda n: ((n.key, n.values.min())))),
)
@classmethod
def root_node(cls, children: Iterable["Qube"]) -> "Qube":
def root_node(cls, children: Iterable[Qube]) -> Qube:
return cls.make("root", QEnum(("root",)), children)
@classmethod
def load(cls, path: str | Path) -> "Qube":
def load(cls, path: str | Path) -> Qube:
with open(path, "r") as f:
return Qube.from_json(json.load(f))
@classmethod
def from_datacube(cls, datacube: dict[str, str | Sequence[str]]) -> "Qube":
def from_datacube(cls, datacube: dict[str, str | Sequence[str]]) -> Qube:
key_vals = list(datacube.items())[::-1]
children: list["Qube"] = []
children: list[Qube] = []
for key, values in key_vals:
values_group: ValueGroup
if values == "*":
values = WildcardGroup()
elif not isinstance(values, list):
values = [values]
values_group = WildcardGroup()
elif isinstance(values, list):
values_group = QEnum(values)
else:
values_group = QEnum([values])
if isinstance(values, list):
values = QEnum(values)
children = [cls.make(key, values, children)]
children = [cls.make(key, values_group, children)]
return cls.root_node(children)
@classmethod
def from_json(cls, json: dict) -> "Qube":
def from_json(cls, json: dict) -> Qube:
def from_json(json: dict) -> Qube:
return Qube.make(
key=json["key"],
@ -112,7 +123,7 @@ class Qube:
return to_json(self)
@classmethod
def from_dict(cls, d: dict) -> "Qube":
def from_dict(cls, d: dict) -> Qube:
def from_dict(d: dict) -> Iterator[Qube]:
for k, children in d.items():
key, values = k.split("=")
@ -131,8 +142,8 @@ class Qube:
return Qube.root_node(list(from_dict(d)))
def to_dict(self) -> dict:
def to_dict(q: "Qube") -> tuple[str, dict]:
key = f"{q.key}={','.join(str(v) for v in q.values.values)}"
def to_dict(q: Qube) -> tuple[str, dict]:
key = f"{q.key}={','.join(str(v) for v in q.values)}"
return key, dict(to_dict(c) for c in q.children)
return to_dict(self)[1]
@ -189,10 +200,10 @@ class Qube:
return cls.from_dict(root)
@classmethod
def empty(cls) -> "Qube":
def empty(cls) -> Qube:
return Qube.root_node([])
def __str__(self, depth=None, name=None) -> str:
def __str_helper__(self, depth=None, name=None) -> str:
node = (
dataclasses.replace(
self,
@ -203,8 +214,11 @@ class Qube:
)
return "".join(node_tree_to_string(node=node, depth=depth))
def __str__(self):
return self.__str_helper__()
def print(self, depth=None, name: str | None = None):
print(self.__str__(depth=depth, name=name))
print(self.__str_helper__(depth=depth, name=name))
def html(self, depth=2, collapse=True, name: str | None = None) -> HTML:
node = (
@ -221,27 +235,27 @@ class Qube:
return node_tree_to_html(self, depth=2, collapse=True)
# Allow "key=value/value" / qube to prepend keys
def __rtruediv__(self, other: str) -> "Qube":
def __rtruediv__(self, other: str) -> Qube:
key, values = other.split("=")
values = QEnum((values.split("/")))
return Qube.root_node([Qube.make(key, values, self.children)])
values_enum = QEnum((values.split("/")))
return Qube.root_node([Qube.make(key, values_enum, self.children)])
def __or__(self, other: "Qube") -> "Qube":
def __or__(self, other: Qube) -> Qube:
return set_operations.operation(
self, other, set_operations.SetOperation.UNION, type(self)
)
def __and__(self, other: "Qube") -> "Qube":
def __and__(self, other: Qube) -> Qube:
return set_operations.operation(
self, other, set_operations.SetOperation.INTERSECTION, type(self)
)
def __sub__(self, other: "Qube") -> "Qube":
def __sub__(self, other: Qube) -> Qube:
return set_operations.operation(
self, other, set_operations.SetOperation.DIFFERENCE, type(self)
)
def __xor__(self, other: "Qube") -> "Qube":
def __xor__(self, other: Qube) -> Qube:
return set_operations.operation(
self, other, set_operations.SetOperation.SYMMETRIC_DIFFERENCE, type(self)
)
@ -270,11 +284,10 @@ class Qube:
def leaves_with_metadata(
self, indices=()
) -> Iterable[tuple[dict[str, str], dict[str, str]]]:
) -> Iterator[tuple[dict[str, str], dict[str, str | np.ndarray]]]:
if self.key == "root":
for c in self.children:
for leaf in c.leaves_with_metadata(indices=()):
yield leaf
yield from c.leaves_with_metadata(indices=())
return
for index, value in enumerate(self.values):
@ -305,21 +318,21 @@ class Qube:
yield from to_list_of_cubes(c)
if not node.children:
yield {node.key: list(node.values.values)}
yield {node.key: list(node.values)}
for c in node.children:
for sub_cube in to_list_of_cubes(c):
yield {node.key: list(node.values.values)} | sub_cube
yield {node.key: list(node.values)} | sub_cube
return to_list_of_cubes(self)
def __getitem__(self, args) -> "Qube":
def __getitem__(self, args) -> Qube:
if isinstance(args, str):
specifiers = args.split(",")
current = self
for specifier in specifiers:
key, values = specifier.split("=")
values = values.split("/")
key, values_str = specifier.split("=")
values = values_str.split("/")
for c in current.children:
if c.key == key and set(values) == set(c.values):
current = c
@ -354,7 +367,7 @@ class Qube:
return 0
return 1 + sum(c.n_nodes for c in self.children)
def transform(self, func: "Callable[[Qube], Qube | Iterable[Qube]]") -> "Qube":
def transform(self, func: "Callable[[Qube], Qube | Iterable[Qube]]") -> Qube:
"""
Call a function on every node of the Qube, return one or more nodes.
If multiple nodes are returned they each get a copy of the (transformed) children of the original node.
@ -375,8 +388,8 @@ class Qube:
def remove_by_key(self, keys: str | list[str]):
_keys: list[str] = keys if isinstance(keys, list) else [keys]
def remove_key(node: "Qube") -> "Qube":
children = []
def remove_key(node: Qube) -> Qube:
children: list[Qube] = []
for c in node.children:
if c.key in _keys:
grandchildren = tuple(sorted(remove_key(cc) for cc in c.children))
@ -405,17 +418,23 @@ class Qube:
mode: Literal["strict", "relaxed"] = "relaxed",
prune=True,
consume=False,
) -> "Qube":
# make all values lists
selection: dict[str, list[str] | Callable[[Any], bool]] = {
k: v if isinstance(v, list | Callable) else [v]
for k, v in selection.items()
}
) -> Qube:
# Find any bare str values and replace them with [str]
_selection: dict[str, list[str] | Callable[[Any], bool]] = {}
for k, v in selection.items():
if isinstance(v, list):
_selection[k] = v
elif callable(v):
_selection[k] = v
else:
_selection[k] = [v]
def not_none(xs):
return tuple(x for x in xs if x is not None)
def select(node: Qube, selection: dict[str, list[str]]) -> Qube | None:
def select(
node: Qube, selection: dict[str, list[str] | Callable[[Any], bool]]
) -> Qube | None:
# If this node has no children but there are still parts of the request
# that have not been consumed, then prune this whole branch
if consume and not node.children and selection:
@ -442,13 +461,15 @@ class Qube:
raise ValueError(f"Unknown mode argument {mode}")
# If the key IS in the selection then check if the values match
if node.key in selection:
if node.key in _selection:
# If the key is specified, check if any of the values match
selection_criteria = selection[node.key]
if isinstance(selection_criteria, Callable):
selection_criteria = _selection[node.key]
if callable(selection_criteria):
values = QEnum((c for c in node.values if selection_criteria(c)))
elif isinstance(selection_criteria, list):
values = QEnum((c for c in selection_criteria if c in node.values))
else:
values = QEnum((c for c in selection[node.key] if c in node.values))
raise ValueError(f"Unknown selection type {selection_criteria}")
# Here modes don't matter because we've explicitly filtered on this key and found nothing
if not values:
@ -468,11 +489,11 @@ class Qube:
return node.replace(
children=new_children,
metadata=self.metadata | {"is_leaf": not bool(new_children)},
metadata=dict(self.metadata) | {"is_leaf": not bool(new_children)},
)
return self.replace(
children=not_none(select(c, selection) for c in self.children)
children=not_none(select(c, _selection) for c in self.children)
)
def span(self, key: str) -> list[str]:
@ -508,7 +529,7 @@ class Qube:
return hash_node(self)
def compress(self) -> "Qube":
def compress(self) -> Qube:
"""
This method is quite computationally heavy because of trees like this:
root, class=d1, generation=1
@ -519,7 +540,7 @@ class Qube:
"""
def union(a: "Qube", b: "Qube") -> "Qube":
def union(a: Qube, b: Qube) -> Qube:
b = type(self).root_node(children=(b,))
out = set_operations.operation(
a, b, set_operations.SetOperation.UNION, type(self)
@ -528,6 +549,8 @@ class Qube:
new_children = [c.compress() for c in self.children]
if len(new_children) > 1:
new_children = functools.reduce(union, new_children, Qube.empty()).children
new_children = list(
functools.reduce(union, new_children, Qube.empty()).children
)
return self.replace(children=tuple(sorted(new_children)))

View File

@ -8,8 +8,10 @@ def parse_key_value_pairs(text: str):
for segment in text.split(","):
if "=" not in segment:
print(segment)
key, values = segment.split("=", 1) # Ensure split only happens at first "="
values = values.split("/")
key, values_str = segment.split(
"=", 1
) # Ensure split only happens at first "="
values = values_str.split("/")
result[key] = values
return result

View File

@ -1,6 +1,7 @@
from dataclasses import dataclass, field
from typing import Hashable
from typing import Mapping
import numpy as np
from frozendict import frozendict
from .value_types import ValueGroup
@ -10,7 +11,7 @@ from .value_types import ValueGroup
class NodeData:
key: str
values: ValueGroup
metadata: dict[str, tuple[Hashable, ...]] = field(
metadata: Mapping[str, np.ndarray] = field(
default_factory=frozendict, compare=False
)

View File

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from collections import defaultdict
from dataclasses import replace
from enum import Enum
@ -11,7 +13,7 @@ from .node_types import NodeData
from .value_types import QEnum, ValueGroup, WildcardGroup
if TYPE_CHECKING:
from .qube import Qube
from .Qube import Qube
class SetOperation(Enum):
@ -22,7 +24,7 @@ class SetOperation(Enum):
def node_intersection(
A: "ValueGroup", B: "ValueGroup"
A: ValueGroup, B: ValueGroup
) -> tuple[ValueGroup, ValueGroup, ValueGroup]:
if isinstance(A, QEnum) and isinstance(B, QEnum):
set_A, set_B = set(A), set(B)
@ -49,7 +51,7 @@ def node_intersection(
)
def operation(A: "Qube", B: "Qube", operation_type: SetOperation, node_type) -> "Qube":
def operation(A: Qube, B: Qube, operation_type: SetOperation, node_type) -> Qube:
assert A.key == B.key, (
"The two Qube root nodes must have the same key to perform set operations,"
f"would usually be two root nodes. They have {A.key} and {B.key} respectively"
@ -60,13 +62,15 @@ def operation(A: "Qube", B: "Qube", operation_type: SetOperation, node_type) ->
)
# Group the children of the two nodes by key
nodes_by_key = defaultdict(lambda: ([], []))
nodes_by_key: defaultdict[str, tuple[list[Qube], list[Qube]]] = defaultdict(
lambda: ([], [])
)
for node in A.children:
nodes_by_key[node.key][0].append(node)
for node in B.children:
nodes_by_key[node.key][1].append(node)
new_children = []
new_children: list[Qube] = []
# For every node group, perform the set operation
for key, (A_nodes, B_nodes) in nodes_by_key.items():
@ -77,16 +81,16 @@ def operation(A: "Qube", B: "Qube", operation_type: SetOperation, node_type) ->
# Whenever we modify children we should recompress them
# But since `operation` is already recursive, we only need to compress this level not all levels
# Hence we use the non-recursive _compress method
new_children = compress_children(new_children)
new_children = list(compress_children(new_children))
# The values and key are the same so we just replace the children
return replace(A, children=new_children)
return A.replace(children=tuple(sorted(new_children)))
# The root node is special so we need a helper method that we can recurse on
def _operation(
key: str, A: list["Qube"], B: list["Qube"], operation_type: SetOperation, node_type
) -> Iterable["Qube"]:
key: str, A: list[Qube], B: list[Qube], operation_type: SetOperation, node_type
) -> Iterable[Qube]:
keep_just_A, keep_intersection, keep_just_B = operation_type.value
# Iterate over all pairs (node_A, node_B)
@ -128,7 +132,7 @@ def _operation(
yield node_type.make(key, values[node], node.children)
def compress_children(children: Iterable["Qube"]) -> tuple["Qube"]:
def compress_children(children: Iterable[Qube]) -> tuple[Qube, ...]:
"""
Helper method tht only compresses a set of nodes, and doesn't do it recursively.
Used in Qubed.compress but also to maintain compression in the set operations above.
@ -138,16 +142,16 @@ def compress_children(children: Iterable["Qube"]) -> tuple["Qube"]:
identical_children = defaultdict(set)
for child in children:
# only care about the key and children of each node, ignore values
key = hash((child.key, tuple((cc.structural_hash for cc in child.children))))
identical_children[key].add(child)
h = hash((child.key, tuple((cc.structural_hash for cc in child.children))))
identical_children[h].add(child)
# Now go through and create new compressed nodes for any groups that need collapsing
new_children = []
for child_set in identical_children.values():
if len(child_set) > 1:
child_set = list(child_set)
node_type = type(child_set[0])
key = child_set[0].key
child_list = list(child_set)
node_type = type(child_list[0])
key = child_list[0].key
# Compress the children into a single node
assert all(isinstance(child.data.values, QEnum) for child in child_set), (
@ -155,13 +159,11 @@ def compress_children(children: Iterable["Qube"]) -> tuple["Qube"]:
)
node_data = NodeData(
key=key,
key=str(key),
metadata=frozendict(), # Todo: Implement metadata compression
values=QEnum(
(v for child in child_set for v in child.data.values.values)
),
values=QEnum((v for child in child_set for v in child.data.values)),
)
new_child = node_type(data=node_data, children=child_set[0].children)
new_child = node_type(data=node_data, children=child_list[0].children)
else:
# If the group is size one just keep it
new_child = child_set.pop()
@ -170,7 +172,7 @@ def compress_children(children: Iterable["Qube"]) -> tuple["Qube"]:
return tuple(sorted(new_children, key=lambda n: ((n.key, n.values.min()))))
def union(a: "Qube", b: "Qube") -> "Qube":
def union(a: Qube, b: Qube) -> Qube:
return operation(
a,
b,

View File

@ -1,16 +1,11 @@
from __future__ import annotations
import random
from dataclasses import dataclass
from typing import Iterable, Protocol, Sequence, runtime_checkable
from typing import TYPE_CHECKING, Iterable
@runtime_checkable
class TreeLike(Protocol):
@property
def children(
self,
) -> Sequence["TreeLike"]: ... # Supports indexing like node.children[i]
def summary(self) -> str: ...
if TYPE_CHECKING:
from .Qube import Qube
@dataclass(frozen=True)
@ -22,8 +17,8 @@ class HTML:
def summarize_node(
node: TreeLike, collapse=False, max_summary_length=50, **kwargs
) -> tuple[str, str, TreeLike]:
node: Qube, collapse=False, max_summary_length=50, **kwargs
) -> tuple[str, str, Qube]:
"""
Extracts a summarized representation of the node while collapsing single-child paths.
Returns the summary string and the last node in the chain that has multiple children.
@ -50,7 +45,7 @@ def summarize_node(
return ", ".join(summaries), ",".join(paths), node
def node_tree_to_string(node: TreeLike, prefix: str = "", depth=None) -> Iterable[str]:
def node_tree_to_string(node: Qube, prefix: str = "", depth=None) -> Iterable[str]:
summary, path, node = summarize_node(node)
if depth is not None and depth <= 0:
@ -74,7 +69,7 @@ def node_tree_to_string(node: TreeLike, prefix: str = "", depth=None) -> Iterabl
def _node_tree_to_html(
node: TreeLike, prefix: str = "", depth=1, connector="", **kwargs
node: Qube, prefix: str = "", depth=1, connector="", **kwargs
) -> Iterable[str]:
summary, path, node = summarize_node(node, **kwargs)
@ -99,7 +94,7 @@ def _node_tree_to_html(
def node_tree_to_html(
node: TreeLike, depth=1, include_css=True, include_js=True, css_id=None, **kwargs
node: Qube, depth=1, include_css=True, include_js=True, css_id=None, **kwargs
) -> str:
if css_id is None:
css_id = f"qubed-tree-{random.randint(0, 1000000)}"

View File

@ -1,8 +1,19 @@
from __future__ import annotations
import dataclasses
from abc import ABC, abstractmethod
from dataclasses import dataclass, replace
from datetime import date, datetime, timedelta
from typing import TYPE_CHECKING, Any, FrozenSet, Iterable, Literal, TypeVar
from typing import (
TYPE_CHECKING,
Any,
FrozenSet,
Iterable,
Iterator,
Literal,
Sequence,
TypeVar,
)
if TYPE_CHECKING:
from .Qube import Qube
@ -30,23 +41,19 @@ class ValueGroup(ABC):
"Return the minimum value in the group."
pass
@dataclass(frozen=True)
class FiniteValueGroup(ValueGroup, ABC):
@classmethod
@abstractmethod
def __len__(self) -> int:
"Return how many values this group contains."
def from_strings(cls, values: Iterable[str]) -> Sequence[ValueGroup]:
"Given a list of strings, return a one or more ValueGroups of this type."
pass
@abstractmethod
def __iter__(self) -> Iterable[Any]:
def __iter__(self) -> Iterator:
"Iterate over the values in the group."
pass
@classmethod
@abstractmethod
def from_strings(cls, values: Iterable[str]) -> list["ValueGroup"]:
"Given a list of strings, return a one or more ValueGroups of this type."
def __len__(self) -> int:
pass
@ -55,7 +62,7 @@ EnumValuesType = FrozenSet[T]
@dataclass(frozen=True, order=True)
class QEnum(FiniteValueGroup):
class QEnum(ValueGroup):
"""
The simplest kind of key value is just a list of strings.
summary -> string1/string2/string....
@ -81,8 +88,9 @@ class QEnum(FiniteValueGroup):
def __contains__(self, value: Any) -> bool:
return value in self.values
def from_strings(self, values: Iterable[str]) -> list["ValueGroup"]:
return [type(self)(tuple(values))]
@classmethod
def from_strings(cls, values: Iterable[str]) -> Sequence[ValueGroup]:
return [cls(tuple(values))]
def min(self):
return min(self.values)
@ -105,6 +113,16 @@ class WildcardGroup(ValueGroup):
def min(self):
return "*"
def __len__(self):
return None
def __iter__(self):
return ["*"]
@classmethod
def from_strings(cls, values: Iterable[str]) -> Sequence[ValueGroup]:
return [WildcardGroup()]
class DateEnum(QEnum):
def summary(self) -> str:
@ -125,7 +143,7 @@ class Range(ValueGroup, ABC):
def min(self):
return self.start
def __iter__(self) -> Iterable[Any]:
def __iter__(self) -> Iterator[Any]:
i = self.start
while i <= self.end:
yield i
@ -145,19 +163,19 @@ class DateRange(Range):
def __len__(self) -> int:
return (self.end - self.start) // self.step
def __iter__(self) -> Iterable[date]:
def __iter__(self) -> Iterator[date]:
current = self.start
while current <= self.end if self.step.days > 0 else current >= self.end:
yield current
current += self.step
@classmethod
def from_strings(cls, values: Iterable[str]) -> "list[DateRange | QEnum]":
def from_strings(cls, values: Iterable[str]) -> Sequence[DateRange | DateEnum]:
dates = sorted([datetime.strptime(v, "%Y%m%d") for v in values])
if len(dates) < 2:
return [DateEnum(dates)]
ranges = []
ranges: list[DateEnum | DateRange] = []
current_group, dates = (
[
dates[0],
@ -243,7 +261,7 @@ class TimeRange(Range):
def min(self):
return self.start
def __iter__(self) -> Iterable[Any]:
def __iter__(self) -> Iterator[Any]:
return super().__iter__()
@classmethod
@ -369,7 +387,7 @@ def values_from_json(obj) -> ValueGroup:
def convert_datatypes(q: "Qube", conversions: dict[str, ValueGroup]) -> "Qube":
def _convert(q: "Qube") -> Iterable["Qube"]:
def _convert(q: "Qube") -> Iterator["Qube"]:
if q.key in conversions:
data_type = conversions[q.key]
assert isinstance(q.values, QEnum), (