473 lines
14 KiB
Python
473 lines
14 KiB
Python
import dataclasses
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass
|
|
from datetime import date, datetime, timedelta
|
|
from typing import Any, Callable, Iterable, Literal
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class HTML:
|
|
html: str
|
|
|
|
def _repr_html_(self):
|
|
return self.html
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class Values(ABC):
|
|
@abstractmethod
|
|
def summary(self) -> str:
|
|
pass
|
|
|
|
@abstractmethod
|
|
def __len__(self) -> int:
|
|
pass
|
|
|
|
@abstractmethod
|
|
def __contains__(self, value: Any) -> bool:
|
|
pass
|
|
|
|
@abstractmethod
|
|
def from_strings(self, values: list[str]) -> list["Values"]:
|
|
pass
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class Enum(Values):
|
|
"""
|
|
The simplest kind of key value is just a list of strings.
|
|
summary -> string1/string2/string....
|
|
"""
|
|
|
|
values: list[Any]
|
|
|
|
def __len__(self) -> int:
|
|
return len(self.values)
|
|
|
|
def summary(self) -> str:
|
|
return "/".join(sorted(self.values))
|
|
|
|
def __contains__(self, value: Any) -> bool:
|
|
return value in self.values
|
|
|
|
def from_strings(self, values: list[str]) -> list["Values"]:
|
|
return [Enum(values)]
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class Range(Values, ABC):
|
|
dtype: str = dataclasses.field(kw_only=True)
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class DateRange(Range):
|
|
start: date
|
|
end: date
|
|
step: timedelta
|
|
dtype: Literal["date"] = dataclasses.field(kw_only=True, default="date")
|
|
|
|
@classmethod
|
|
def from_strings(self, values: list[str]) -> list["DateRange"]:
|
|
dates = sorted([datetime.strptime(v, "%Y%m%d") for v in values])
|
|
if len(dates) < 2:
|
|
return [DateRange(start=dates[0], end=dates[0], step=timedelta(days=0))]
|
|
|
|
ranges = []
|
|
current_range, dates = (
|
|
[
|
|
dates[0],
|
|
],
|
|
dates[1:],
|
|
)
|
|
while len(dates) > 1:
|
|
if dates[0] - current_range[-1] == timedelta(days=1):
|
|
current_range.append(dates.pop(0))
|
|
|
|
elif len(current_range) == 1:
|
|
ranges.append(
|
|
DateRange(
|
|
start=current_range[0],
|
|
end=current_range[0],
|
|
step=timedelta(days=0),
|
|
)
|
|
)
|
|
current_range = [
|
|
dates.pop(0),
|
|
]
|
|
|
|
else:
|
|
ranges.append(
|
|
DateRange(
|
|
start=current_range[0],
|
|
end=current_range[-1],
|
|
step=timedelta(days=1),
|
|
)
|
|
)
|
|
current_range = [
|
|
dates.pop(0),
|
|
]
|
|
return ranges
|
|
|
|
def __contains__(self, value: Any) -> bool:
|
|
v = datetime.strptime(value, "%Y%m%d").date()
|
|
return self.start <= v <= self.end and (v - self.start) % self.step == 0
|
|
|
|
def __len__(self) -> int:
|
|
return (self.end - self.start) // self.step
|
|
|
|
def summary(self) -> str:
|
|
def fmt(d):
|
|
return d.strftime("%Y%m%d")
|
|
|
|
if self.step == timedelta(days=0):
|
|
return f"{fmt(self.start)}"
|
|
if self.step == timedelta(days=1):
|
|
return f"{fmt(self.start)}/to/{fmt(self.end)}"
|
|
|
|
return (
|
|
f"{fmt(self.start)}/to/{fmt(self.end)}/by/{self.step // timedelta(days=1)}"
|
|
)
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class TimeRange(Range):
|
|
start: int
|
|
end: int
|
|
step: int
|
|
dtype: Literal["time"] = dataclasses.field(kw_only=True, default="time")
|
|
|
|
@classmethod
|
|
def from_strings(self, values: list[str]) -> list["TimeRange"]:
|
|
if len(values) == 0:
|
|
return []
|
|
|
|
times = sorted([int(v) for v in values])
|
|
if len(times) < 2:
|
|
return [TimeRange(start=times[0], end=times[0], step=100)]
|
|
|
|
ranges = []
|
|
current_range, times = (
|
|
[
|
|
times[0],
|
|
],
|
|
times[1:],
|
|
)
|
|
while len(times) > 1:
|
|
if times[0] - current_range[-1] == 1:
|
|
current_range.append(times.pop(0))
|
|
|
|
elif len(current_range) == 1:
|
|
ranges.append(
|
|
TimeRange(start=current_range[0], end=current_range[0], step=0)
|
|
)
|
|
current_range = [
|
|
times.pop(0),
|
|
]
|
|
|
|
else:
|
|
ranges.append(
|
|
TimeRange(start=current_range[0], end=current_range[-1], step=1)
|
|
)
|
|
current_range = [
|
|
times.pop(0),
|
|
]
|
|
return ranges
|
|
|
|
def __len__(self) -> int:
|
|
return (self.end - self.start) // self.step
|
|
|
|
def summary(self) -> str:
|
|
def fmt(d):
|
|
return f"{d:04d}"
|
|
|
|
if self.step == 0:
|
|
return f"{fmt(self.start)}"
|
|
return f"{fmt(self.start)}/to/{fmt(self.end)}/by/{self.step}"
|
|
|
|
def __contains__(self, value: Any) -> bool:
|
|
v = int(value)
|
|
return self.start <= v <= self.end and (v - self.start) % self.step == 0
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class IntRange(Range):
|
|
dtype: Literal["int"]
|
|
start: int
|
|
end: int
|
|
step: int
|
|
dtype: Literal["int"] = dataclasses.field(kw_only=True, default="int")
|
|
|
|
def __len__(self) -> int:
|
|
return (self.end - self.start) // self.step
|
|
|
|
def summary(self) -> str:
|
|
def fmt(d):
|
|
return d.strftime("%Y%m%d")
|
|
|
|
return f"{fmt(self.start)}/to/{fmt(self.end)}/by/{self.step}"
|
|
|
|
def __contains__(self, value: Any) -> bool:
|
|
v = int(value)
|
|
return self.start <= v <= self.end and (v - self.start) % self.step == 0
|
|
|
|
|
|
def values_from_json(obj) -> Values:
|
|
if isinstance(obj, list):
|
|
return Enum(obj)
|
|
|
|
match obj["dtype"]:
|
|
case "date":
|
|
return DateRange(**obj)
|
|
case "time":
|
|
return TimeRange(**obj)
|
|
case "int":
|
|
return IntRange(**obj)
|
|
case _:
|
|
raise ValueError(f"Unknown dtype {obj['dtype']}")
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class Node:
|
|
key: str
|
|
values: Values # Must support len()
|
|
metadata: dict[str, str] # Applies to all children
|
|
payload: list[Any] # List of size product(len(n.values) for n in ancestors(self))
|
|
children: list["Node"]
|
|
|
|
|
|
def summarize_node(node: Node) -> tuple[str, Node]:
|
|
"""
|
|
Extracts a summarized representation of the node while collapsing single-child paths.
|
|
Returns the summary string and the last node in the chain that has multiple children.
|
|
"""
|
|
summary = []
|
|
|
|
while True:
|
|
values_summary = node.values.summary()
|
|
if len(values_summary) > 50:
|
|
values_summary = values_summary[:50] + "..."
|
|
summary.append(f"{node.key}={values_summary}")
|
|
|
|
# Move down if there's exactly one child, otherwise stop
|
|
if len(node.children) != 1:
|
|
break
|
|
node = node.children[0]
|
|
|
|
return ", ".join(summary), node
|
|
|
|
|
|
def node_tree_to_string(node: Node, prefix: str = "", depth=None) -> Iterable[str]:
|
|
summary, node = summarize_node(node)
|
|
|
|
if depth is not None and depth <= 0:
|
|
yield summary + " - ...\n"
|
|
return
|
|
# Special case for nodes with only a single child, this makes the printed representation more compact
|
|
elif len(node.children) == 1:
|
|
yield summary + ", "
|
|
yield from node_tree_to_string(node.children[0], prefix, depth=depth)
|
|
return
|
|
else:
|
|
yield summary + "\n"
|
|
|
|
for index, child in enumerate(node.children):
|
|
connector = "└── " if index == len(node.children) - 1 else "├── "
|
|
yield prefix + connector
|
|
extension = " " if index == len(node.children) - 1 else "│ "
|
|
yield from node_tree_to_string(
|
|
child, prefix + extension, depth=depth - 1 if depth is not None else None
|
|
)
|
|
|
|
|
|
def node_tree_to_html(
|
|
node: Node, prefix: str = "", depth=1, connector=""
|
|
) -> Iterable[str]:
|
|
summary, node = summarize_node(node)
|
|
|
|
if len(node.children) == 0:
|
|
yield f'<span class="leaf">{connector}{summary}</span>'
|
|
return
|
|
else:
|
|
open = "open" if depth > 0 else ""
|
|
yield f"<details {open}><summary>{connector}{summary}</summary>"
|
|
|
|
for index, child in enumerate(node.children):
|
|
connector = "└── " if index == len(node.children) - 1 else "├── "
|
|
extension = " " if index == len(node.children) - 1 else "│ "
|
|
yield from node_tree_to_html(
|
|
child, prefix + extension, depth=depth - 1, connector=prefix + connector
|
|
)
|
|
yield "</details>"
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class CompressedTree:
|
|
root: Node
|
|
|
|
@classmethod
|
|
def from_json(cls, json: dict) -> "CompressedTree":
|
|
def from_json(json: dict) -> Node:
|
|
return Node(
|
|
key=json["key"],
|
|
values=values_from_json(json["values"]),
|
|
metadata=json["metadata"] if "metadata" in json else {},
|
|
payload=json["payload"] if "payload" in json else [],
|
|
children=[from_json(c) for c in json["children"]],
|
|
)
|
|
|
|
return CompressedTree(root=from_json(json))
|
|
|
|
def __str__(self):
|
|
return "".join(node_tree_to_string(node=self.root))
|
|
|
|
def html(self, depth=2) -> HTML:
|
|
return HTML(self._repr_html_(depth=depth))
|
|
|
|
def _repr_html_(self, depth=2):
|
|
css = """
|
|
<style>
|
|
.qubed-tree-view {
|
|
font-family: monospace;
|
|
white-space: pre;
|
|
}
|
|
.qubed-tree-view details {
|
|
# display: inline;
|
|
margin-left: 0;
|
|
}
|
|
.qubed-tree-view summary {
|
|
list-style: none;
|
|
cursor: pointer;
|
|
text-overflow: ellipsis;
|
|
overflow: hidden;
|
|
text-wrap: nowrap;
|
|
display: block;
|
|
}
|
|
|
|
.qubed-tree-view .leaf {
|
|
text-overflow: ellipsis;
|
|
overflow: hidden;
|
|
text-wrap: nowrap;
|
|
display: block;
|
|
}
|
|
|
|
.qubed-tree-view summary:hover,span.leaf:hover {
|
|
background-color: #f0f0f0;
|
|
}
|
|
.qubed-tree-view details > summary::after {
|
|
content: ' ';
|
|
}
|
|
.qubed-tree-view details:not([open]) > summary::after {
|
|
content: " ▼";
|
|
}
|
|
</style>
|
|
|
|
"""
|
|
nodes = "".join(
|
|
cc
|
|
for c in self.root.children
|
|
for cc in node_tree_to_html(node=c, depth=depth)
|
|
)
|
|
return f"{css}<pre class='qubed-tree-view'>{nodes}</pre>"
|
|
|
|
def print(self, depth=None):
|
|
print(
|
|
"".join(
|
|
cc
|
|
for c in self.root.children
|
|
for cc in node_tree_to_string(node=c, depth=depth)
|
|
)
|
|
)
|
|
|
|
def transform(self, func: Callable[[Node], Node]) -> "CompressedTree":
|
|
"Call a function on every node of the tree, any changes to the children of a node will be ignored."
|
|
|
|
def transform(node: Node) -> Node:
|
|
new_node = func(node)
|
|
return dataclasses.replace(
|
|
new_node, children=[transform(c) for c in node.children]
|
|
)
|
|
|
|
return CompressedTree(root=transform(self.root))
|
|
|
|
def guess_datatypes(self) -> "CompressedTree":
|
|
def guess_datatypes(node: Node) -> list[Node]:
|
|
# Try to convert enum values into more structured types
|
|
children = [cc for c in node.children for cc in guess_datatypes(c)]
|
|
|
|
if isinstance(node.values, Enum):
|
|
match node.key:
|
|
case "time":
|
|
range_class = TimeRange
|
|
case "date":
|
|
range_class = DateRange
|
|
case _:
|
|
range_class = None
|
|
|
|
if range_class is not None:
|
|
return [
|
|
dataclasses.replace(node, values=range, children=children)
|
|
for range in range_class.from_strings(node.values.values)
|
|
]
|
|
return [dataclasses.replace(node, children=children)]
|
|
|
|
children = [cc for c in self.root.children for cc in guess_datatypes(c)]
|
|
return CompressedTree(root=dataclasses.replace(self.root, children=children))
|
|
|
|
def select(
|
|
self,
|
|
selection: dict[str, str | list[str]],
|
|
mode: Literal["strict", "relaxed"] = "relaxed",
|
|
) -> "CompressedTree":
|
|
# make all values lists
|
|
selection = {k: v if isinstance(v, list) else [v] for k, v in selection.items()}
|
|
|
|
def not_none(xs):
|
|
return [x for x in xs if x is not None]
|
|
|
|
def select(node: Node) -> Node | None:
|
|
# Check if the key is specified in the selection
|
|
if node.key not in selection:
|
|
if mode == "strict":
|
|
return None
|
|
return dataclasses.replace(
|
|
node, children=not_none(select(c) for c in node.children)
|
|
)
|
|
|
|
# If the key is specified, check if any of the values match
|
|
values = Enum([c for c in selection[node.key] if c in node.values])
|
|
|
|
if not values:
|
|
return None
|
|
|
|
return dataclasses.replace(
|
|
node, values=values, children=not_none(select(c) for c in node.children)
|
|
)
|
|
|
|
return CompressedTree(
|
|
root=dataclasses.replace(
|
|
self.root, children=not_none(select(c) for c in self.root.children)
|
|
)
|
|
)
|
|
|
|
def to_list_of_cubes(self):
|
|
def to_list_of_cubes(node: Node) -> list[list[Node]]:
|
|
return [
|
|
[node] + sub_cube
|
|
for c in node.children
|
|
for sub_cube in to_list_of_cubes(c)
|
|
]
|
|
|
|
return to_list_of_cubes(self.root)
|
|
|
|
def info(self):
|
|
cubes = self.to_list_of_cubes()
|
|
print(f"Number of distinct paths: {len(cubes)}")
|
|
|
|
|
|
# What should the interace look like?
|
|
|
|
# tree = CompressedTree.from_json(...)
|
|
# tree = CompressedTree.from_protobuf(...)
|
|
|
|
# tree.print(depth = 5) # Prints a nice tree representation
|