import dataclasses from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import date, datetime, timedelta from typing import Any, Callable, Iterable, Literal @dataclass(frozen=True) class HTML: html: str def _repr_html_(self): return self.html @dataclass(frozen=True) class Values(ABC): @abstractmethod def summary(self) -> str: pass @abstractmethod def __len__(self) -> int: pass @abstractmethod def __contains__(self, value: Any) -> bool: pass @abstractmethod def from_strings(self, values: list[str]) -> list["Values"]: pass @dataclass(frozen=True) class Enum(Values): """ The simplest kind of key value is just a list of strings. summary -> string1/string2/string.... """ values: list[Any] def __len__(self) -> int: return len(self.values) def summary(self) -> str: return "/".join(sorted(self.values)) def __contains__(self, value: Any) -> bool: return value in self.values def from_strings(self, values: list[str]) -> list["Values"]: return [Enum(values)] @dataclass(frozen=True) class Range(Values, ABC): dtype: str = dataclasses.field(kw_only=True) @dataclass(frozen=True) class DateRange(Range): start: date end: date step: timedelta dtype: Literal["date"] = dataclasses.field(kw_only=True, default="date") @classmethod def from_strings(self, values: list[str]) -> list["DateRange"]: dates = sorted([datetime.strptime(v, "%Y%m%d") for v in values]) if len(dates) < 2: return [DateRange(start=dates[0], end=dates[0], step=timedelta(days=0))] ranges = [] current_range, dates = ( [ dates[0], ], dates[1:], ) while len(dates) > 1: if dates[0] - current_range[-1] == timedelta(days=1): current_range.append(dates.pop(0)) elif len(current_range) == 1: ranges.append( DateRange( start=current_range[0], end=current_range[0], step=timedelta(days=0), ) ) current_range = [ dates.pop(0), ] else: ranges.append( DateRange( start=current_range[0], end=current_range[-1], step=timedelta(days=1), ) ) current_range = [ dates.pop(0), ] return ranges def __contains__(self, value: Any) -> bool: v = datetime.strptime(value, "%Y%m%d").date() return self.start <= v <= self.end and (v - self.start) % self.step == 0 def __len__(self) -> int: return (self.end - self.start) // self.step def summary(self) -> str: def fmt(d): return d.strftime("%Y%m%d") if self.step == timedelta(days=0): return f"{fmt(self.start)}" if self.step == timedelta(days=1): return f"{fmt(self.start)}/to/{fmt(self.end)}" return ( f"{fmt(self.start)}/to/{fmt(self.end)}/by/{self.step // timedelta(days=1)}" ) @dataclass(frozen=True) class TimeRange(Range): start: int end: int step: int dtype: Literal["time"] = dataclasses.field(kw_only=True, default="time") @classmethod def from_strings(self, values: list[str]) -> list["TimeRange"]: if len(values) == 0: return [] times = sorted([int(v) for v in values]) if len(times) < 2: return [TimeRange(start=times[0], end=times[0], step=100)] ranges = [] current_range, times = ( [ times[0], ], times[1:], ) while len(times) > 1: if times[0] - current_range[-1] == 1: current_range.append(times.pop(0)) elif len(current_range) == 1: ranges.append( TimeRange(start=current_range[0], end=current_range[0], step=0) ) current_range = [ times.pop(0), ] else: ranges.append( TimeRange(start=current_range[0], end=current_range[-1], step=1) ) current_range = [ times.pop(0), ] return ranges def __len__(self) -> int: return (self.end - self.start) // self.step def summary(self) -> str: def fmt(d): return f"{d:04d}" if self.step == 0: return f"{fmt(self.start)}" return f"{fmt(self.start)}/to/{fmt(self.end)}/by/{self.step}" def __contains__(self, value: Any) -> bool: v = int(value) return self.start <= v <= self.end and (v - self.start) % self.step == 0 @dataclass(frozen=True) class IntRange(Range): dtype: Literal["int"] start: int end: int step: int dtype: Literal["int"] = dataclasses.field(kw_only=True, default="int") def __len__(self) -> int: return (self.end - self.start) // self.step def summary(self) -> str: def fmt(d): return d.strftime("%Y%m%d") return f"{fmt(self.start)}/to/{fmt(self.end)}/by/{self.step}" def __contains__(self, value: Any) -> bool: v = int(value) return self.start <= v <= self.end and (v - self.start) % self.step == 0 def values_from_json(obj) -> Values: if isinstance(obj, list): return Enum(obj) match obj["dtype"]: case "date": return DateRange(**obj) case "time": return TimeRange(**obj) case "int": return IntRange(**obj) case _: raise ValueError(f"Unknown dtype {obj['dtype']}") @dataclass(frozen=True) class Node: key: str values: Values # Must support len() metadata: dict[str, str] # Applies to all children payload: list[Any] # List of size product(len(n.values) for n in ancestors(self)) children: list["Node"] def summarize_node(node: Node) -> tuple[str, Node]: """ Extracts a summarized representation of the node while collapsing single-child paths. Returns the summary string and the last node in the chain that has multiple children. """ summary = [] while True: values_summary = node.values.summary() if len(values_summary) > 50: values_summary = values_summary[:50] + "..." summary.append(f"{node.key}={values_summary}") # Move down if there's exactly one child, otherwise stop if len(node.children) != 1: break node = node.children[0] return ", ".join(summary), node def node_tree_to_string(node: Node, prefix: str = "", depth=None) -> Iterable[str]: summary, node = summarize_node(node) if depth is not None and depth <= 0: yield summary + " - ...\n" return # Special case for nodes with only a single child, this makes the printed representation more compact elif len(node.children) == 1: yield summary + ", " yield from node_tree_to_string(node.children[0], prefix, depth=depth) return else: yield summary + "\n" for index, child in enumerate(node.children): connector = "└── " if index == len(node.children) - 1 else "├── " yield prefix + connector extension = " " if index == len(node.children) - 1 else "│ " yield from node_tree_to_string( child, prefix + extension, depth=depth - 1 if depth is not None else None ) def node_tree_to_html( node: Node, prefix: str = "", depth=1, connector="" ) -> Iterable[str]: summary, node = summarize_node(node) if len(node.children) == 0: yield f'{connector}{summary}' return else: open = "open" if depth > 0 else "" yield f"
{connector}{summary}" for index, child in enumerate(node.children): connector = "└── " if index == len(node.children) - 1 else "├── " extension = " " if index == len(node.children) - 1 else "│ " yield from node_tree_to_html( child, prefix + extension, depth=depth - 1, connector=prefix + connector ) yield "
" @dataclass(frozen=True) class CompressedTree: root: Node @classmethod def from_json(cls, json: dict) -> "CompressedTree": def from_json(json: dict) -> Node: return Node( key=json["key"], values=values_from_json(json["values"]), metadata=json["metadata"] if "metadata" in json else {}, payload=json["payload"] if "payload" in json else [], children=[from_json(c) for c in json["children"]], ) return CompressedTree(root=from_json(json)) def __str__(self): return "".join(node_tree_to_string(node=self.root)) def html(self, depth=2) -> HTML: return HTML(self._repr_html_(depth=depth)) def _repr_html_(self, depth=2): css = """ """ nodes = "".join( cc for c in self.root.children for cc in node_tree_to_html(node=c, depth=depth) ) return f"{css}
{nodes}
" def print(self, depth=None): print( "".join( cc for c in self.root.children for cc in node_tree_to_string(node=c, depth=depth) ) ) def transform(self, func: Callable[[Node], Node]) -> "CompressedTree": "Call a function on every node of the tree, any changes to the children of a node will be ignored." def transform(node: Node) -> Node: new_node = func(node) return dataclasses.replace( new_node, children=[transform(c) for c in node.children] ) return CompressedTree(root=transform(self.root)) def guess_datatypes(self) -> "CompressedTree": def guess_datatypes(node: Node) -> list[Node]: # Try to convert enum values into more structured types children = [cc for c in node.children for cc in guess_datatypes(c)] if isinstance(node.values, Enum): match node.key: case "time": range_class = TimeRange case "date": range_class = DateRange case _: range_class = None if range_class is not None: return [ dataclasses.replace(node, values=range, children=children) for range in range_class.from_strings(node.values.values) ] return [dataclasses.replace(node, children=children)] children = [cc for c in self.root.children for cc in guess_datatypes(c)] return CompressedTree(root=dataclasses.replace(self.root, children=children)) def select( self, selection: dict[str, str | list[str]], mode: Literal["strict", "relaxed"] = "relaxed", ) -> "CompressedTree": # make all values lists selection = {k: v if isinstance(v, list) else [v] for k, v in selection.items()} def not_none(xs): return [x for x in xs if x is not None] def select(node: Node) -> Node | None: # Check if the key is specified in the selection if node.key not in selection: if mode == "strict": return None return dataclasses.replace( node, children=not_none(select(c) for c in node.children) ) # If the key is specified, check if any of the values match values = Enum([c for c in selection[node.key] if c in node.values]) if not values: return None return dataclasses.replace( node, values=values, children=not_none(select(c) for c in node.children) ) return CompressedTree( root=dataclasses.replace( self.root, children=not_none(select(c) for c in self.root.children) ) ) def to_list_of_cubes(self): def to_list_of_cubes(node: Node) -> list[list[Node]]: return [ [node] + sub_cube for c in node.children for sub_cube in to_list_of_cubes(c) ] return to_list_of_cubes(self.root) def info(self): cubes = self.to_list_of_cubes() print(f"Number of distinct paths: {len(cubes)}") # What should the interace look like? # tree = CompressedTree.from_json(...) # tree = CompressedTree.from_protobuf(...) # tree.print(depth = 5) # Prints a nice tree representation