import dataclasses from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import date, datetime, timedelta from typing import Any, Callable, Iterable, Literal @dataclass(frozen=True) class HTML(): html: str def _repr_html_(self): return self.html @dataclass(frozen=True) class Values(ABC): @abstractmethod def summary(self) -> str: pass @abstractmethod def __len__(self) -> int: pass @abstractmethod def __contains__(self, value: Any) -> bool: pass @abstractmethod def from_strings(self, values: list[str]) -> list['Values']: pass @dataclass(frozen=True) class Enum(Values): """ The simplest kind of key value is just a list of strings. summary -> string1/string2/string.... """ values: list[Any] def __len__(self) -> int: return len(self.values) def summary(self) -> str: return '/'.join(sorted(self.values)) def __contains__(self, value: Any) -> bool: return value in self.values def from_strings(self, values: list[str]) -> list['Values']: return [Enum(values)] @dataclass(frozen=True) class Range(Values, ABC): dtype: str = dataclasses.field(kw_only=True) @dataclass(frozen=True) class DateRange(Range): start: date end: date step: timedelta dtype: Literal["date"] = dataclasses.field(kw_only=True, default="date") @classmethod def from_strings(self, values: list[str]) -> list['DateRange']: dates = sorted([datetime.strptime(v, "%Y%m%d") for v in values]) if len(dates) < 2: return [DateRange( start=dates[0], end=dates[0], step=timedelta(days=0) )] ranges = [] current_range, dates = [dates[0],], dates[1:] while len(dates) > 1: if dates[0] - current_range[-1] == timedelta(days=1): current_range.append(dates.pop(0)) elif len(current_range) == 1: ranges.append(DateRange( start=current_range[0], end=current_range[0], step=timedelta(days=0) )) current_range = [dates.pop(0),] else: ranges.append(DateRange( start=current_range[0], end=current_range[-1], step=timedelta(days=1) )) current_range = [dates.pop(0),] return ranges def __contains__(self, value: Any) -> bool: v = datetime.strptime(value, "%Y%m%d").date() return self.start <= v <= self.end and (v - self.start) % self.step == 0 def __len__(self) -> int: return (self.end - self.start) // self.step def summary(self) -> str: def fmt(d): return d.strftime("%Y%m%d") if self.step == timedelta(days=0): return f"{fmt(self.start)}" if self.step == timedelta(days=1): return f"{fmt(self.start)}/to/{fmt(self.end)}" return f"{fmt(self.start)}/to/{fmt(self.end)}/by/{self.step // timedelta(days=1)}" @dataclass(frozen=True) class TimeRange(Range): start: int end: int step: int dtype: Literal["time"] = dataclasses.field(kw_only=True, default="time") @classmethod def from_strings(self, values: list[str]) -> list['TimeRange']: if len(values) == 0: return [] times = sorted([int(v) for v in values]) if len(times) < 2: return [TimeRange( start=times[0], end=times[0], step=100 )] ranges = [] current_range, times = [times[0],], times[1:] while len(times) > 1: if times[0] - current_range[-1] == 1: current_range.append(times.pop(0)) elif len(current_range) == 1: ranges.append(TimeRange( start=current_range[0], end=current_range[0], step=0 )) current_range = [times.pop(0),] else: ranges.append(TimeRange( start=current_range[0], end=current_range[-1], step=1 )) current_range = [times.pop(0),] return ranges def __len__(self) -> int: return (self.end - self.start) // self.step def summary(self) -> str: def fmt(d): return f"{d:04d}" if self.step == 0: return f"{fmt(self.start)}" return f"{fmt(self.start)}/to/{fmt(self.end)}/by/{self.step}" def __contains__(self, value: Any) -> bool: v = int(value) return self.start <= v <= self.end and (v - self.start) % self.step == 0 @dataclass(frozen=True) class IntRange(Range): dtype: Literal["int"] start: int end: int step: int dtype: Literal["int"] = dataclasses.field(kw_only=True, default="int") def __len__(self) -> int: return (self.end - self.start) // self.step def summary(self) -> str: def fmt(d): return d.strftime("%Y%m%d") return f"{fmt(self.start)}/to/{fmt(self.end)}/by/{self.step}" def __contains__(self, value: Any) -> bool: v = int(value) return self.start <= v <= self.end and (v - self.start) % self.step == 0 def values_from_json(obj) -> Values: if isinstance(obj, list): return Enum(obj) match obj["dtype"]: case "date": return DateRange(**obj) case "time": return TimeRange(**obj) case "int": return IntRange(**obj) case _: raise ValueError(f"Unknown dtype {obj['dtype']}") @dataclass(frozen=True) class Node: key: str values: Values # Must support len() metadata: dict[str, str] # Applies to all children payload: list[Any] # List of size product(len(n.values) for n in ancestors(self)) children: list['Node'] def summarize_node(node: Node) -> tuple[str, Node]: """ Extracts a summarized representation of the node while collapsing single-child paths. Returns the summary string and the last node in the chain that has multiple children. """ summary = [] while True: values_summary = node.values.summary() if len(values_summary) > 50: values_summary = values_summary[:50] + "..." summary.append(f"{node.key}={values_summary}") # Move down if there's exactly one child, otherwise stop if len(node.children) != 1: break node = node.children[0] return ", ".join(summary), node def node_tree_to_string(node : Node, prefix : str = "", depth = None) -> Iterable[str]: summary, node = summarize_node(node) if depth is not None and depth <= 0: yield summary + " - ...\n" return # Special case for nodes with only a single child, this makes the printed representation more compact elif len(node.children) == 1: yield summary + ", " yield from node_tree_to_string(node.children[0], prefix, depth = depth) return else: yield summary + "\n" for index, child in enumerate(node.children): connector = "└── " if index == len(node.children) - 1 else "├── " yield prefix + connector extension = " " if index == len(node.children) - 1 else "│ " yield from node_tree_to_string(child, prefix + extension, depth = depth - 1 if depth is not None else None) def node_tree_to_html(node : Node, prefix : str = "", depth = 1, connector = "") -> Iterable[str]: summary, node = summarize_node(node) if len(node.children) == 0: yield f'{connector}{summary}' return else: open = "open" if depth > 0 else "" yield f"
{connector}{summary}" for index, child in enumerate(node.children): connector = "└── " if index == len(node.children) - 1 else "├── " extension = " " if index == len(node.children) - 1 else "│ " yield from node_tree_to_html(child, prefix + extension, depth = depth - 1, connector = prefix+connector) yield "
" @dataclass(frozen=True) class CompressedTree: root: Node @classmethod def from_json(cls, json: dict) -> 'CompressedTree': def from_json(json: dict) -> Node: return Node( key=json["key"], values=values_from_json(json["values"]), metadata=json["metadata"] if "metadata" in json else {}, payload=json["payload"] if "payload" in json else [], children=[from_json(c) for c in json["children"]] ) return CompressedTree(root=from_json(json)) def __str__(self): return "".join(node_tree_to_string(node=self.root)) def html(self, depth = 2) -> HTML: return HTML(self._repr_html_(depth = depth)) def _repr_html_(self, depth = 2): css = """ """ nodes = "".join(cc for c in self.root.children for cc in node_tree_to_html(node=c, depth=depth)) return f"{css}
{nodes}
" def print(self, depth = None): print("".join(cc for c in self.root.children for cc in node_tree_to_string(node=c, depth = depth))) def transform(self, func: Callable[[Node], Node]) -> 'CompressedTree': "Call a function on every node of the tree, any changes to the children of a node will be ignored." def transform(node: Node) -> Node: new_node = func(node) return dataclasses.replace(new_node, children = [transform(c) for c in node.children]) return CompressedTree(root=transform(self.root)) def guess_datatypes(self) -> 'CompressedTree': def guess_datatypes(node: Node) -> list[Node]: # Try to convert enum values into more structured types children = [cc for c in node.children for cc in guess_datatypes(c)] if isinstance(node.values, Enum): match node.key: case "time": range_class = TimeRange case "date": range_class = DateRange case _: range_class = None if range_class is not None: return [ dataclasses.replace(node, values = range, children = children) for range in range_class.from_strings(node.values.values) ] return [dataclasses.replace(node, children = children)] children = [cc for c in self.root.children for cc in guess_datatypes(c)] return CompressedTree(root=dataclasses.replace(self.root, children = children)) def select(self, selection : dict[str, str | list[str]], mode: Literal["strict", "relaxed"] = "relaxed") -> 'CompressedTree': # make all values lists selection = {k : v if isinstance(v, list) else [v] for k,v in selection.items()} def not_none(xs): return [x for x in xs if x is not None] def select(node: Node) -> Node | None: # Check if the key is specified in the selection if node.key not in selection: if mode == "strict": return None return dataclasses.replace(node, children = not_none(select(c) for c in node.children)) # If the key is specified, check if any of the values match values = Enum([ c for c in selection[node.key] if c in node.values]) if not values: return None return dataclasses.replace(node, values = values, children = not_none(select(c) for c in node.children)) return CompressedTree(root=dataclasses.replace(self.root, children = not_none(select(c) for c in self.root.children))) def to_list_of_cubes(self): def to_list_of_cubes(node: Node) -> list[list[Node]]: return [[node] + sub_cube for c in node.children for sub_cube in to_list_of_cubes(c)] return to_list_of_cubes(self.root) def info(self): cubes = self.to_list_of_cubes() print(f"Number of distinct paths: {len(cubes)}") # What should the interace look like? # tree = CompressedTree.from_json(...) # tree = CompressedTree.from_protobuf(...) # tree.print(depth = 5) # Prints a nice tree representation