import dataclasses from dataclasses import dataclass, field from typing import Any, Callable, Hashable, Literal, Mapping from frozendict import frozendict from .tree_formatters import HTML, node_tree_to_html, node_tree_to_string from .value_types import DateRange, Enum, IntRange, TimeRange, Values def values_from_json(obj) -> Values: if isinstance(obj, list): return Enum(tuple(obj)) match obj["dtype"]: case "date": return DateRange(**obj) case "time": return TimeRange(**obj) case "int": return IntRange(**obj) case _: raise ValueError(f"Unknown dtype {obj['dtype']}") # In practice use a frozendict Metadata = Mapping[str, str | int | float | bool] @dataclass(frozen=True, eq=True, order=True) class NodeData: key: str values: Values metadata: dict[str, tuple[Hashable, ...]] = field(default_factory=frozendict, compare=False) def summary(self) -> str: return f"{self.key}={self.values.summary()}" if self.key != "root" else "root" @dataclass(frozen=True, eq=True, order=True) class Tree: data: NodeData children: tuple['Tree', ...] @property def key(self) -> str: return self.data.key @property def values(self) -> Values: return self.data.values @property def metadata(self) -> frozendict[str, Any]: return self.data.metadata def summary(self) -> str: return self.data.summary() @classmethod def make(cls, key : str, values : Values, children, **kwargs) -> 'Tree': return cls( data = NodeData(key, values, metadata = kwargs.get("metadata", frozendict()) ), children = tuple(sorted(children)), ) @classmethod def from_json(cls, json: dict) -> 'Tree': def from_json(json: dict) -> Tree: return Tree.make( key=json["key"], values=values_from_json(json["values"]), metadata=json["metadata"] if "metadata" in json else {}, children=tuple(from_json(c) for c in json["children"]) ) return from_json(json) @classmethod def from_dict(cls, d: dict) -> 'Tree': def from_dict(d: dict) -> tuple[Tree, ...]: return tuple(Tree.make( key=k.split("=")[0], values=Enum(tuple(k.split("=")[1].split("/"))), children=from_dict(children) ) for k, children in d.items()) return Tree.make(key = "root", values=Enum(("root",)), children = from_dict(d)) @classmethod def empty(cls) -> 'Tree': return cls.make("root", Enum(("root",)), []) def __str__(self): return "".join(node_tree_to_string(node=self)) def html(self, depth = 2, collapse = True) -> HTML: return HTML(node_tree_to_html(self, depth = depth, collapse = collapse)) def _repr_html_(self) -> str: return node_tree_to_html(self, depth = 2, collapse = True) def __getitem__(self, args) -> 'Tree': key, value = args for c in self.children: if c.key == key and value in c.values: data = dataclasses.replace(c.data, values = Enum((value,))) return dataclasses.replace(c, data = data) raise KeyError(f"Key {key} not found in children of {self.key}") def print(self, depth = None): print("".join(cc for c in self.children for cc in node_tree_to_string(node=c, depth = depth))) def transform(self, func: 'Callable[[Tree], Tree | list[Tree]]') -> 'Tree': """ Call a function on every node of the tree, return one or more nodes. If multiple nodes are returned they each get a copy of the (transformed) children of the original node. Any changes to the children of a node will be ignored. """ def transform(node: Tree) -> list[Tree]: children = [cc for c in node.children for cc in transform(c)] new_nodes = func(node) if isinstance(new_nodes, Tree): new_nodes = [new_nodes] return [dataclasses.replace(new_node, children = children) for new_node in new_nodes] children = tuple(cc for c in self.children for cc in transform(c)) return dataclasses.replace(self, children = children) def guess_datatypes(self) -> 'Tree': def guess_datatypes(node: Tree) -> list[Tree]: # Try to convert enum values into more structured types children = tuple(cc for c in node.children for cc in guess_datatypes(c)) if isinstance(node.values, Enum): match node.key: case "time": range_class = TimeRange case "date": range_class = DateRange case _: range_class = None if range_class is not None: return [ dataclasses.replace(node, values = range, children = children) for range in range_class.from_strings(node.values.values) ] return [dataclasses.replace(node, children = children)] children = tuple(cc for c in self.children for cc in guess_datatypes(c)) return dataclasses.replace(self, children = children) def select(self, selection : dict[str, str | list[str]], mode: Literal["strict", "relaxed"] = "relaxed") -> 'Tree': # make all values lists selection = {k : v if isinstance(v, list) else [v] for k,v in selection.items()} def not_none(xs): return tuple(x for x in xs if x is not None) def select(node: Tree) -> Tree | None: # Check if the key is specified in the selection if node.key not in selection: if mode == "strict": return None return dataclasses.replace(node, children = not_none(select(c) for c in node.children)) # If the key is specified, check if any of the values match values = Enum(tuple(c for c in selection[node.key] if c in node.values)) if not values: return None return dataclasses.replace(node, values = values, children = not_none(select(c) for c in node.children)) return dataclasses.replace(self, children = not_none(select(c) for c in self.children)) @staticmethod def _insert(position: "Tree", identifier : list[tuple[str, list[str]]]): """ This algorithm goes as follows: We're at a particular node in the tree, and we have a list of key-values pairs that we want to insert. We take the first key values pair key, values = identifier.pop(0) The general idea is to insert key, values into the current node and use recursion to handle the rest of the identifier. We have two sources of values with possible overlap. The values to insert and the values attached to the children of this node. For each value coming from either source we put it in one of three categories: 1) Values that exist only in the already existing child. (Coming exclusively from position.children) 2) Values that exist in both a child and the new values. 3) Values that exist only in the new values. Thus we add the values to insert to a set, and loop over the children. For each child we partition its values into the three categories. For 1) we create a new child node with the key, reduced set of values and the same children. For 2) Create a new child node with the key, and the values in group 2 Recurse to compute the children Once we have finished looping over children we know all the values left over came exclusively from the new values. So we: Create a new node with these values. Recurse to compute the children Finally we return the node with all these new children. """ if not identifier: return position key, values = identifier.pop(0) # print(f"Inserting {key}={values} into {position.summary()}") # Determine which children have this key possible_children = {c : [] for c in position.children if c.key == key} entirely_new_values = [] # For each value check it is already in one of the children for v in values: for c in possible_children: if v in c.values: possible_children[c].append(v) break else: # only executed if the loop did not break # If none of the children have this value, add it to the new child pile entirely_new_values.append(v) # d = {p.summary() : v for p, v in possible_children.items()} # print(f" {d} new_values={entirely_new_values}") new_children = [] for c, affected in possible_children.items(): if not affected: new_children.append(c) continue unaffected = [x for x in c.values if x not in affected] if unaffected: unaffected_node = Tree.make(c.key, Enum(tuple(unaffected)), c.children) new_children.append(unaffected_node) # Add the unaffected part of this child if affected: # This check is not technically necessary, but it makes the code more readable new_node = Tree.make(key, Enum(tuple(affected)), []) new_node = Tree._insert(new_node, identifier) new_children.append(new_node) # Add the affected part of this child # If there are any values not in any of the existing children, add them as a new child if entirely_new_values: new_node = Tree.make(key, Enum(tuple(entirely_new_values)), []) new_children.append(Tree._insert(new_node, identifier)) return Tree.make(position.key, position.values, new_children) def insert(self, identifier : dict[str, list[str]]) -> 'Tree': insertion = [(k, v) for k, v in identifier.items()] return Tree._insert(self, insertion) def to_list_of_cubes(self): def to_list_of_cubes(node: Tree) -> list[list[Tree]]: return [[node] + sub_cube for c in node.children for sub_cube in to_list_of_cubes(c)] return to_list_of_cubes(self) def info(self): cubes = self.to_list_of_cubes() print(f"Number of distinct paths: {len(cubes)}")