qubed/notebooks/DataCubeTree.py
2025-02-06 13:18:31 +00:00

395 lines
13 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import dataclasses
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import date, datetime, timedelta
from typing import Any, Callable, Iterable, Literal
@dataclass(frozen=True)
class HTML:
    """Immutable wrapper around a string of HTML markup.

    The only behaviour is ``_repr_html_``, which hands the wrapped markup
    back unchanged (presumably so notebook front ends render it as HTML).
    """

    # The markup returned verbatim by _repr_html_.
    html: str

    def _repr_html_(self):
        """Return the wrapped markup exactly as stored."""
        markup = self.html
        return markup
@dataclass(frozen=True)
class Values(ABC):
    """Abstract interface for the set of values attached to a key.

    Concrete subclasses must provide a textual summary, a length, a
    membership test and a way to build instances from raw strings.
    """

    @abstractmethod
    def summary(self) -> str:
        """Return a compact, human-readable description of the values."""

    @abstractmethod
    def __len__(self) -> int:
        """Return how many individual values this object represents."""

    @abstractmethod
    def __contains__(self, value: Any) -> bool:
        """Return whether ``value`` is one of the represented values."""

    @abstractmethod
    def from_strings(self, values: list[str]) -> list['Values']:
        """Build one or more value objects from a list of raw strings."""
@dataclass(frozen=True)
class Enum(Values):
    """
    The simplest kind of key value is just a list of values (usually strings).
    summary -> string1/string2/string....
    """

    # The explicit list of allowed values.
    values: list[Any]

    def __len__(self) -> int:
        """Return the number of listed values."""
        return len(self.values)

    def summary(self) -> str:
        """Return the sorted values joined with ``/``.

        Values are sorted in their native order and then converted with
        ``str``, so non-string values (the field is ``list[Any]``) no longer
        make ``str.join`` raise ``TypeError``. For all-string input the
        result is unchanged.
        """
        return '/'.join(str(v) for v in sorted(self.values))

    def __contains__(self, value: Any) -> bool:
        """Return whether ``value`` is one of the listed values."""
        return value in self.values

    @classmethod
    def from_strings(cls, values: list[str]) -> list['Values']:
        """Wrap ``values`` in a single ``Enum``.

        Declared as a classmethod, like the range classes in this file, so it
        can be called on the class as well as on an instance.
        """
        return [cls(values)]
@dataclass(frozen=True)
class Range(Values, ABC):
    """Abstract base for range-style value sets, tagged by a ``dtype`` string."""

    # Keyword-only and without a default at this level; the concrete range
    # classes in this file redeclare it with their own literal default.
    dtype: str = dataclasses.field(kw_only=True)
@dataclass(frozen=True)
class DateRange(Range):
    """An inclusive range of calendar dates from ``start`` to ``end`` every ``step``.

    A zero ``step`` denotes a single date (``start``); this is the convention
    used by ``from_strings`` and ``summary``.
    """

    start: date
    end: date
    step: timedelta
    dtype: Literal["date"] = dataclasses.field(kw_only=True, default="date")

    @classmethod
    def from_strings(cls, values: list[str]) -> list['DateRange']:
        """Group ``YYYYMMDD`` strings into runs of consecutive days.

        Returns one ``DateRange`` per run, in ascending order. A run of a
        single day gets a zero step; longer runs get a one-day step.
        Duplicate inputs are collapsed and an empty input yields ``[]``.

        Raises:
            ValueError: if a string does not match ``%Y%m%d``.
        """
        if not values:
            return []
        # Store plain dates (not datetimes) so the fields match their declared
        # type and compare cleanly with the dates parsed in __contains__.
        days = sorted({datetime.strptime(v, "%Y%m%d").date() for v in values})
        one_day = timedelta(days=1)
        runs: list[list[date]] = [[days[0]]]
        for day in days[1:]:
            if day - runs[-1][-1] == one_day:
                runs[-1].append(day)
            else:
                runs.append([day])
        # Every run is emitted, including the final one.
        return [
            cls(
                start=run[0],
                end=run[-1],
                step=one_day if len(run) > 1 else timedelta(days=0),
            )
            for run in runs
        ]

    def __contains__(self, value: Any) -> bool:
        """Return whether ``value`` falls on one of the dates in the range.

        ``value`` may be a ``YYYYMMDD`` string, a ``date`` or a ``datetime``.
        Anything that cannot be interpreted as a date is simply not a member.
        """
        if isinstance(value, datetime):
            day = value.date()
        elif isinstance(value, date):
            day = value
        else:
            try:
                day = datetime.strptime(value, "%Y%m%d").date()
            except (TypeError, ValueError):
                return False
        if not self.start <= day <= self.end:
            return False
        if not self.step:
            # Zero step means a single date; avoid dividing by zero.
            return day == self.start
        # timedelta % timedelta is a timedelta, so compare against a zero
        # timedelta rather than the integer 0 (which never compares equal).
        return (day - self.start) % self.step == timedelta(0)

    def __len__(self) -> int:
        """Return the number of dates in the range, counting both endpoints."""
        if not self.step:
            return 1
        return (self.end - self.start) // self.step + 1

    def summary(self) -> str:
        """Return ``start``, ``start/to/end`` or ``start/to/end/by/N`` (N in days)."""
        def fmt(d):
            return d.strftime("%Y%m%d")

        if self.step == timedelta(days=0):
            return f"{fmt(self.start)}"
        if self.step == timedelta(days=1):
            return f"{fmt(self.start)}/to/{fmt(self.end)}"
        return f"{fmt(self.start)}/to/{fmt(self.end)}/by/{self.step // timedelta(days=1)}"
@dataclass(frozen=True)
class TimeRange(Range):
    """An inclusive range of integer time values from ``start`` to ``end`` every ``step``.

    A zero ``step`` denotes a single value (``start``); ``summary`` prints
    such a range as just that value.
    """

    start: int
    end: int
    step: int
    dtype: Literal["time"] = dataclasses.field(kw_only=True, default="time")

    @classmethod
    def from_strings(cls, values: list[str]) -> list['TimeRange']:
        """Group integer strings into runs of values that differ by exactly 1.

        Returns one ``TimeRange`` per run, in ascending order. A run of a
        single value gets ``step=0`` (also when it is the only input, so it
        agrees with ``summary``); longer runs get ``step=1``. Duplicates are
        collapsed and an empty input yields ``[]``.

        Raises:
            ValueError: if a string cannot be converted with ``int``.
        """
        if not values:
            return []
        times = sorted({int(v) for v in values})
        runs: list[list[int]] = [[times[0]]]
        for t in times[1:]:
            if t - runs[-1][-1] == 1:
                runs[-1].append(t)
            else:
                runs.append([t])
        # Every run is emitted, including the final one.
        return [
            cls(start=run[0], end=run[-1], step=1 if len(run) > 1 else 0)
            for run in runs
        ]

    def __len__(self) -> int:
        """Return the number of values in the range, counting both endpoints."""
        if self.step == 0:
            return 1
        return (self.end - self.start) // self.step + 1

    def summary(self) -> str:
        """Return ``start`` (zero step) or ``start/to/end/by/step``, zero-padded to 4 digits."""
        def fmt(d):
            return f"{d:04d}"

        if self.step == 0:
            return f"{fmt(self.start)}"
        return f"{fmt(self.start)}/to/{fmt(self.end)}/by/{self.step}"

    def __contains__(self, value: Any) -> bool:
        """Return whether ``int(value)`` is one of the values in the range.

        Values that cannot be converted with ``int`` are not members.
        """
        try:
            v = int(value)
        except (TypeError, ValueError):
            return False
        if not self.start <= v <= self.end:
            return False
        if self.step == 0:
            # Zero step means a single value; avoid dividing by zero.
            return v == self.start
        return (v - self.start) % self.step == 0
@dataclass(frozen=True)
class IntRange(Range):
    """An inclusive range of integers from ``start`` to ``end`` every ``step``.

    A zero ``step`` denotes a single value (``start``).
    """

    start: int
    end: int
    step: int
    # Declared once (the duplicate bare annotation was redundant).
    dtype: Literal["int"] = dataclasses.field(kw_only=True, default="int")

    @classmethod
    def from_strings(cls, values: list[str]) -> list['IntRange']:
        """Group integer strings into runs of consecutive integers.

        Implements the abstract ``from_strings`` so the class can actually be
        instantiated. Mirrors ``TimeRange.from_strings``: one range per run,
        ``step=0`` for a lone value and ``step=1`` otherwise; duplicates are
        collapsed and an empty input yields ``[]``.

        Raises:
            ValueError: if a string cannot be converted with ``int``.
        """
        if not values:
            return []
        numbers = sorted({int(v) for v in values})
        runs: list[list[int]] = [[numbers[0]]]
        for n in numbers[1:]:
            if n - runs[-1][-1] == 1:
                runs[-1].append(n)
            else:
                runs.append([n])
        return [
            cls(start=run[0], end=run[-1], step=1 if len(run) > 1 else 0)
            for run in runs
        ]

    def __len__(self) -> int:
        """Return the number of integers in the range, counting both endpoints."""
        if self.step == 0:
            return 1
        return (self.end - self.start) // self.step + 1

    def summary(self) -> str:
        """Return ``start`` (zero step) or ``start/to/end/by/step``.

        The bounds are integers, so they are formatted directly rather than
        with ``strftime``.
        """
        if self.step == 0:
            return f"{self.start}"
        return f"{self.start}/to/{self.end}/by/{self.step}"

    def __contains__(self, value: Any) -> bool:
        """Return whether ``int(value)`` is one of the integers in the range.

        Values that cannot be converted with ``int`` are not members.
        """
        try:
            v = int(value)
        except (TypeError, ValueError):
            return False
        if not self.start <= v <= self.end:
            return False
        if self.step == 0:
            # Zero step means a single value; avoid dividing by zero.
            return v == self.start
        return (v - self.start) % self.step == 0
def values_from_json(obj) -> Values:
    """Turn a decoded JSON value description into a ``Values`` object.

    A list becomes an ``Enum``. Anything else is indexed with ``"dtype"``
    and its items are passed as keyword arguments to the matching range
    class.

    Raises:
        ValueError: if ``obj["dtype"]` is not "date", "time" or "int".
    """
    if isinstance(obj, list):
        return Enum(obj)

    dtype = obj["dtype"]
    if dtype == "date":
        return DateRange(**obj)
    if dtype == "time":
        return TimeRange(**obj)
    if dtype == "int":
        return IntRange(**obj)
    raise ValueError(f"Unknown dtype {dtype}")
@dataclass(frozen=True)
class Node:
    """A single tree node: one key, the values it takes, and its sub-trees."""

    # Name of the key this node describes.
    key: str

    # Values for the key; must support len().
    values: Values

    # Applies to all children.
    metadata: dict[str, str]

    # List of size product(len(n.values) for n in ancestors(self)).
    payload: list[Any]

    # Child nodes below this one.
    children: list['Node']
def summarize_node(node: Node) -> tuple[str, Node]:
    """
    Build a one-line summary for ``node``, following single-child chains.

    Each visited node contributes ``key=<values summary>``, with the values
    summary cut to 50 characters plus "..." when longer. Descent continues
    while the current node has exactly one child. Returns the comma-joined
    summary and the node where the descent stopped (the first node in the
    chain whose number of children is not exactly one).
    """
    def describe(current: Node) -> str:
        text = current.values.summary()
        clipped = text if len(text) <= 50 else text[:50] + "..."
        return f"{current.key}={clipped}"

    parts = [describe(node)]
    while len(node.children) == 1:
        node = node.children[0]
        parts.append(describe(node))
    return ", ".join(parts), node
def node_tree_to_string(node : Node, prefix : str = "", depth: int | None = None) -> Iterable[str]:
    """Yield the pieces of a text tree drawing for ``node`` and its descendants.

    Args:
        node: Root of the subtree to draw.
        prefix: Text placed before each connector on lines below this node.
        depth: How many more levels to expand; ``None`` means no limit. At
            ``depth <= 0`` the node is printed followed by " - ..." and its
            children are omitted.

    Yields:
        String fragments which, joined together, form the drawing.
    """
    # summarize_node already walks down single-child chains, so the node it
    # returns never has exactly one child; no separate single-child case is
    # needed here.
    summary, node = summarize_node(node)

    if depth is not None and depth <= 0:
        yield summary + " - ...\n"
        return
    yield summary + "\n"

    child_depth = None if depth is None else depth - 1
    last_index = len(node.children) - 1
    for index, child in enumerate(node.children):
        is_last = index == last_index
        yield prefix + ("└── " if is_last else "├── ")
        # The continuation prefix must be as wide as the 4-column connector,
        # otherwise deeper levels are drawn at their parent's column.
        extension = "    " if is_last else "│   "
        yield from node_tree_to_string(child, prefix + extension, depth = child_depth)
def node_tree_to_html(node : Node, prefix : str = "", depth = 1, connector = "") -> Iterable[str]:
    """Yield HTML fragments drawing ``node`` and its descendants as nested ``<details>``.

    Args:
        node: Root of the subtree to render.
        prefix: Text placed before the connectors of this node's children.
        depth: Number of levels rendered initially open; a ``<details>``
            element gets the ``open`` attribute while ``depth > 0``.
        connector: Tree-drawing text placed in front of this node's summary.

    Yields:
        HTML fragments: a ``<span class="leaf">`` for a node without
        children, otherwise a ``<details>`` element wrapping the children.
    """
    from html import escape

    summary, node = summarize_node(node)
    # Keys and value summaries come from the tree's data; escape them so
    # characters such as "<" or "&" are shown literally instead of being
    # parsed as markup.
    label = escape(connector + summary)

    if not node.children:
        yield f'<span class="leaf">{label}</span>'
        return

    open_attr = "open" if depth > 0 else ""  # avoid shadowing builtin open()
    yield f"<details {open_attr}><summary>{label}</summary>"

    last_index = len(node.children) - 1
    for index, child in enumerate(node.children):
        is_last = index == last_index
        child_connector = "└── " if is_last else "├── "
        # Continuation prefix is as wide as the 4-column connector so that
        # deeper levels line up under their parent.
        extension = "    " if is_last else "│   "
        yield from node_tree_to_html(child, prefix + extension, depth = depth - 1, connector = prefix + child_connector)
    yield "</details>"
@dataclass(frozen=True)
class CompressedTree:
root: Node
@classmethod
def from_json(cls, json: dict) -> 'CompressedTree':
def from_json(json: dict) -> Node:
return Node(
key=json["key"],
values=values_from_json(json["values"]),
metadata=json["metadata"] if "metadata" in json else {},
payload=json["payload"] if "payload" in json else [],
children=[from_json(c) for c in json["children"]]
)
return CompressedTree(root=from_json(json))
def __str__(self):
return "".join(node_tree_to_string(node=self.root))
def html(self, depth = 2) -> HTML:
return HTML(self._repr_html_(depth = depth))
def _repr_html_(self, depth = 2):
css = """
<style>
.qubed-tree-view {
font-family: monospace;
white-space: pre;
}
.qubed-tree-view details {
# display: inline;
margin-left: 0;
}
.qubed-tree-view summary {
list-style: none;
cursor: pointer;
text-overflow: ellipsis;
overflow: hidden;
text-wrap: nowrap;
display: block;
}
.qubed-tree-view .leaf {
text-overflow: ellipsis;
overflow: hidden;
text-wrap: nowrap;
display: block;
}
.qubed-tree-view summary:hover,span.leaf:hover {
background-color: #f0f0f0;
}
.qubed-tree-view details > summary::after {
content: ' ';
}
.qubed-tree-view details:not([open]) > summary::after {
content: "";
}
</style>
"""
nodes = "".join(cc for c in self.root.children for cc in node_tree_to_html(node=c, depth=depth))
return f"{css}<pre class='qubed-tree-view'>{nodes}</pre>"
def print(self, depth = None):
print("".join(cc for c in self.root.children for cc in node_tree_to_string(node=c, depth = depth)))
def transform(self, func: Callable[[Node], Node]) -> 'CompressedTree':
"Call a function on every node of the tree, any changes to the children of a node will be ignored."
def transform(node: Node) -> Node:
new_node = func(node)
return dataclasses.replace(new_node, children = [transform(c) for c in node.children])
return CompressedTree(root=transform(self.root))
def guess_datatypes(self) -> 'CompressedTree':
def guess_datatypes(node: Node) -> list[Node]:
# Try to convert enum values into more structured types
children = [cc for c in node.children for cc in guess_datatypes(c)]
if isinstance(node.values, Enum):
match node.key:
case "time": range_class = TimeRange
case "date": range_class = DateRange
case _: range_class = None
if range_class is not None:
return [
dataclasses.replace(node, values = range, children = children)
for range in range_class.from_strings(node.values.values)
]
return [dataclasses.replace(node, children = children)]
children = [cc for c in self.root.children for cc in guess_datatypes(c)]
return CompressedTree(root=dataclasses.replace(self.root, children = children))
def select(self, selection : dict[str, str | list[str]], mode: Literal["strict", "relaxed"] = "relaxed") -> 'CompressedTree':
# make all values lists
selection = {k : v if isinstance(v, list) else [v] for k,v in selection.items()}
def not_none(xs): return [x for x in xs if x is not None]
def select(node: Node) -> Node | None:
# Check if the key is specified in the selection
if node.key not in selection:
if mode == "strict":
return None
return dataclasses.replace(node, children = not_none(select(c) for c in node.children))
# If the key is specified, check if any of the values match
values = Enum([ c for c in selection[node.key] if c in node.values])
if not values:
return None
return dataclasses.replace(node, values = values, children = not_none(select(c) for c in node.children))
return CompressedTree(root=dataclasses.replace(self.root, children = not_none(select(c) for c in self.root.children)))
def to_list_of_cubes(self):
def to_list_of_cubes(node: Node) -> list[list[Node]]:
return [[node] + sub_cube for c in node.children for sub_cube in to_list_of_cubes(c)]
return to_list_of_cubes(self.root)
def info(self):
cubes = self.to_list_of_cubes()
print(f"Number of distinct paths: {len(cubes)}")
# What should the interface look like?
# tree = CompressedTree.from_json(...)
# tree = CompressedTree.from_protobuf(...)
# tree.print(depth = 5) # Prints a nice tree representation