diff --git a/tree_compresser/Cargo.lock b/tree_compresser/Cargo.lock index 8c23cac..dc5c396 100644 --- a/tree_compresser/Cargo.lock +++ b/tree_compresser/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "aho-corasick" @@ -359,7 +359,6 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "rsfdb" version = "0.1.0" -source = "git+https://github.com/ecmwf/rsfdb?branch=develop#ab8c9590bba15d22167c274db9238cd9b897baf1" dependencies = [ "libc", "libloading", @@ -372,7 +371,6 @@ dependencies = [ [[package]] name = "rsfindlibs" version = "0.1.1" -source = "git+https://github.com/ecmwf-projects/rsfindlibs.git#1358b1049bf3e0b581badfc8005a9828a542cdaa" dependencies = [ "cc", "clap", diff --git a/tree_compresser/Cargo.toml b/tree_compresser/Cargo.toml index bf42c02..37ed7af 100644 --- a/tree_compresser/Cargo.toml +++ b/tree_compresser/Cargo.toml @@ -1,7 +1,8 @@ [package] -name = "qubed_tree" -version = "0.1.0" +name = "qubed" +version = "0.1.2" edition = "2021" +repository = "https://github.com/ecmwf/qubed" [dependencies] rsfdb = {git = "https://github.com/ecmwf/rsfdb", branch = "develop"} @@ -16,7 +17,7 @@ crate-type = ["cdylib"] path = "./rust_src/lib.rs" [patch.'https://github.com/ecmwf/rsfdb'] -rsfdb = { path = "../rsfdb" } +rsfdb = { path = "../../rsfdb" } [patch.'https://github.com/ecmwf-projects/rsfindlibs'] -rsfindlibs = { path = "../rsfindlibs" } \ No newline at end of file +rsfindlibs = { path = "../../rsfindlibs" } \ No newline at end of file diff --git a/tree_compresser/python_src/tree_traverser/CompressedDataCubeTree.py b/tree_compresser/python_src/tree_traverser/CompressedDataCubeTree.py new file mode 100644 index 0000000..82bc869 --- /dev/null +++ b/tree_compresser/python_src/tree_traverser/CompressedDataCubeTree.py @@ -0,0 +1,216 @@ +import dataclasses +from collections import defaultdict +from dataclasses import dataclass, field + +from frozendict import frozendict + +from .DataCubeTree import Enum, NodeData, Tree +from .tree_formatters import HTML, node_tree_to_html, node_tree_to_string + +NodeId = int +CacheType = dict[NodeId, "CompressedNode"] + +@dataclass(frozen=True) +class CompressedNode: + id: NodeId = field(hash=False, compare=False) + data: NodeData + + _children: tuple[NodeId, ...] + _cache: CacheType = field(repr=False, hash=False, compare=False) + + @property + def children(self) -> tuple["CompressedNode", ...]: + return tuple(self._cache[i] for i in self._children) + + def summary(self, debug = False) -> str: + if debug: return f"{self.data.key}={self.data.values.summary()} ({self.id})" + return f"{self.data.key}={self.data.values.summary()}" if self.data.key != "root" else "root" + + +@dataclass(frozen=True) +class CompressedTree: + """ + This tree is compressed in two distinct different ways: + 1. Product Compression: Nodes have a key and **multiple values**, so each node represents many logical nodes key=value1, key=value2, ... + Each of these logical nodes is has identical children so we can compress them like this. + In this way any distinct path through the tree represents a cartesian product of the values, otherwise known as a datacube. + + 2. In order to facilitate the product compression described above we need to know when two nodes have identical children. + To do this every node is assigned an Id which is initially computed as a hash from the nodes data and its childrens' ids. + In order to avoid hash collisions we increment the initial hash if it's already in the cache for a different node + we do this until we find a unique id. + + Crucially this allows us to later determine if a new node is already cached: + id = hash(node) + while True: + if id not in cache: The node is definitely not in the cache + elif cache[id] != node: Hash collision, increment id and try again + else: The node is already in the cache + id += 1 + + This tree can be walked from the root by repeatedly looking up the children of a node in the cache. + + This structure facilitates compression because we can look at the children of a node: + If two chidren have the same key, metadata and children then we can compress them into a single node. + +""" + root: CompressedNode + cache: CacheType + + @staticmethod + def add_to_cache(cache : dict[NodeId, CompressedNode], data : NodeData, _children: tuple[NodeId, ...]) -> NodeId: + """ + This function is responsible for adding a new node to the cache and returning its id. + Crucially we need a way to check if new nodes are already in the cache, so we hash them. + But in case of a hash collision we need to increment the id and try again. + This way we will always eventually find a unique id for the node. + And we will never store the same node twice with a different id. + """ + _children = tuple(sorted(_children)) + id = hash((data, _children)) + + # To avoid hash collisions, we increment the id until we find a unique one + tries = 0 + while True: + tries += 1 + if id not in cache: + # The node isn't in the cache and this id is free + cache[id] = CompressedNode(id = id, + data = data, + _children = _children, + _cache = cache) + break + + if cache[id].data == data and cache[id]._children == _children: + break # The node is already in the cache + + # This id is already in use by a different node so increment it (mod) and try again + id = (id + 1) % (2**64) + + if tries > 100: + raise RuntimeError("Too many hash collisions, something is wrong.") + + return id + + + @classmethod + def from_tree(cls, tree : Tree) -> 'CompressedTree': + cache = {} + + def cache_tree(level : Tree) -> NodeId: + node_data = NodeData( + key = level.key, + values = level.values, + ) + + # Recursively cache the children + children = tuple(cache_tree(c) for c in level.children) + + # Add the node to the cache and return its id + return cls.add_to_cache(cache, node_data, children) + + root = cache_tree(tree) + return cls(cache = cache, root = cache[root]) + + def __str__(self): + return "".join(node_tree_to_string(self.root)) + + def html(self, depth = 2, debug = False) -> HTML: + return HTML(node_tree_to_html(self.root, depth = depth, debug = debug)) + + def _repr_html_(self) -> str: + return node_tree_to_html(self.root, depth = 2) + + def __getitem__(self, args) -> 'CompressedTree': + key, value = args + for c in self.root.children: + if c.data.key == key and value in c.data.values: + data = dataclasses.replace(c.data, values = Enum((value,))) + return CompressedTree( + cache = self.cache, + root = dataclasses.replace(c, data = data) + ) + raise KeyError(f"Key {key} not found in children.") + + def collapse_children(self, node: "CompressedNode") -> "CompressedNode": + # First perform the collapse on the children + new_children = [self.collapse_children(child) for child in node.children] + + # Now take the set of new children and see if any have identical key, metadata and children + # the values may different and will be collapsed into a single node + identical_children = defaultdict(set) + for child in new_children: + identical_children[(child.data.key, child.data.metadata, child._children)].add(child) + + # Now go through and create new compressed nodes for any groups that need collapsing + new_children = [] + for (key, metadata, _children), child_set in identical_children.items(): + if len(child_set) > 1: + # Compress the children into a single node + assert all(isinstance(child.data.values, Enum) for child in child_set), "All children must have Enum values" + node_data = NodeData( + key = key, + metadata = frozendict(), # Todo: Implement metadata compression + values = Enum(tuple(v for child in child_set for v in child.data.values.values)), + ) + + # Add the node to the cache + id = type(self).add_to_cache(self.cache, node_data, _children) + else: + # If the group is size one just keep it + id = child_set.pop().id + + new_children.append(id) + + id = self.add_to_cache(self.cache, node.data, tuple(sorted(new_children))) + return self.cache[id] + + + def compress(self) -> 'CompressedTree': + return CompressedTree(cache = self.cache, root = self.collapse_children(self.root)) + + def lookup(self, selection : dict[str, str]): + nodes = [self.root] + for _ in range(1000): + found = False + current_node = nodes[-1] + for c in current_node.children: + if selection.get(c.data.key, None) in c.data.values: + if found: + raise RuntimeError("This tree is invalid, because it contains overlapping branches.") + nodes.append(c) + selection.pop(c.data.key) + found = True + + if not found: + return nodes + + raise RuntimeError("Maximum node searches exceeded, the tree contains a loop or something is buggy.") + + + + + # def reconstruct(self) -> Tree: + # def reconstruct_node(h : int) -> Tree: + # node = self.cache[h] + # dedup : dict[tuple[int, str], set[NodeId]] = defaultdict(set) + # for index in self.cache[h].children: + # child_node = self.cache[index] + # child_hash = hash(child_node.children) + # assert isinstance(child_node.values, Enum) + # dedup[(child_hash, child_node.key)].add(index) + + + # children = tuple( + # Tree(key = key, values = Enum(tuple(values)), + # children = tuple(reconstruct_node(i) for i in self.cache[next(indices)].children) + # ) + # for (_, key), indices in dedup.items() + # ) + + # return Tree( + # key = node.key, + # values = node.values, + # children = children, + # ) + # return reconstruct_node(self.root) \ No newline at end of file diff --git a/tree_compresser/python_src/tree_traverser/CompressedTree.py b/tree_compresser/python_src/tree_traverser/CompressedTree.py index 1422445..84d5c87 100644 --- a/tree_compresser/python_src/tree_traverser/CompressedTree.py +++ b/tree_compresser/python_src/tree_traverser/CompressedTree.py @@ -1,5 +1,6 @@ import json from collections import defaultdict +from dataclasses import asdict, dataclass from pathlib import Path Tree = dict[str, "Tree"] @@ -13,6 +14,11 @@ class RefcountedDict(dict[str, int]): def __hash__(self): return hash(tuple(sorted(self.items()))) +@dataclass +class JSONNode: + key: str + values: list[str] + children: list["JSONNode"] class CompressedTree(): """ @@ -101,6 +107,23 @@ class CompressedTree(): return {f"{key}={','.join(values)}" : reconstruct_node(h, depth=depth+1) for (h, key), values in dedup.items()} return reconstruct_node(from_node or self.root_hash, depth=0) + def to_json(self, max_depth=None, from_node=None) -> dict: + def reconstruct_node(h : int, depth : int) -> list[JSONNode]: + if max_depth is not None and depth > max_depth: + return {} + dedup : dict[tuple[int, str], set[str]] = defaultdict(set) + for k, h2 in self.cache[h].items(): + key, value = k.split("=") + dedup[(h2, key)].add(value) + + return [JSONNode( + key = key, + values = list(values), + children = reconstruct_node(h, depth=depth+1), + ) for (h, key), values in dedup.items()] + + return asdict(reconstruct_node(from_node or self.root_hash, depth=0)[0]) + def __init__(self, tree : Tree): self.cache = {} self.empty_hash = hash(RefcountedDict({})) @@ -139,8 +162,8 @@ class CompressedTree(): return list(loc.keys()) def multi_match(self, request : dict[str, list[str]], loc = None): - if not loc: return {"_END_" : {}} if loc is None: loc = self.tree + if loc == {}: return {"_END_" : {}} matches = {} for request_key, request_values in request.items(): for request_value in request_values: diff --git a/tree_compresser/python_src/tree_traverser/DataCubeTree.py b/tree_compresser/python_src/tree_traverser/DataCubeTree.py new file mode 100644 index 0000000..9e51802 --- /dev/null +++ b/tree_compresser/python_src/tree_traverser/DataCubeTree.py @@ -0,0 +1,267 @@ +import dataclasses +from dataclasses import dataclass, field +from typing import Any, Callable, Hashable, Literal, Mapping + +from frozendict import frozendict + +from .tree_formatters import HTML, node_tree_to_html, node_tree_to_string +from .value_types import DateRange, Enum, IntRange, TimeRange, Values + + +def values_from_json(obj) -> Values: + if isinstance(obj, list): + return Enum(tuple(obj)) + + match obj["dtype"]: + case "date": return DateRange(**obj) + case "time": return TimeRange(**obj) + case "int": return IntRange(**obj) + case _: raise ValueError(f"Unknown dtype {obj['dtype']}") + +# In practice use a frozendict +Metadata = Mapping[str, str | int | float | bool] + +@dataclass(frozen=True, eq=True, order=True) +class NodeData: + key: str + values: Values + metadata: dict[str, tuple[Hashable, ...]] = field(default_factory=frozendict, compare=False) + + def summary(self) -> str: + return f"{self.key}={self.values.summary()}" if self.key != "root" else "root" + +@dataclass(frozen=True, eq=True, order=True) +class Tree: + data: NodeData + children: tuple['Tree', ...] + + @property + def key(self) -> str: + return self.data.key + + @property + def values(self) -> Values: + return self.data.values + + @property + def metadata(self) -> frozendict[str, Any]: + return self.data.metadata + + + def summary(self) -> str: + return self.data.summary() + + @classmethod + def make(cls, key : str, values : Values, children, **kwargs) -> 'Tree': + return cls( + data = NodeData(key, values, metadata = kwargs.get("metadata", frozendict()) + ), + children = tuple(sorted(children)), + ) + + + @classmethod + def from_json(cls, json: dict) -> 'Tree': + def from_json(json: dict) -> Tree: + return Tree.make( + key=json["key"], + values=values_from_json(json["values"]), + metadata=json["metadata"] if "metadata" in json else {}, + children=tuple(from_json(c) for c in json["children"]) + ) + return from_json(json) + + @classmethod + def from_dict(cls, d: dict) -> 'Tree': + def from_dict(d: dict) -> tuple[Tree, ...]: + return tuple(Tree.make( + key=k.split("=")[0], + values=Enum(tuple(k.split("=")[1].split("/"))), + children=from_dict(children) + ) for k, children in d.items()) + + return Tree.make(key = "root", + values=Enum(("root",)), + children = from_dict(d)) + + @classmethod + def empty(cls) -> 'Tree': + return cls.make("root", Enum(("root",)), []) + + + def __str__(self): + return "".join(node_tree_to_string(node=self)) + + def html(self, depth = 2, collapse = True) -> HTML: + return HTML(node_tree_to_html(self, depth = depth, collapse = collapse)) + + def _repr_html_(self) -> str: + return node_tree_to_html(self, depth = 2, collapse = True) + + def __getitem__(self, args) -> 'Tree': + key, value = args + for c in self.children: + if c.key == key and value in c.values: + data = dataclasses.replace(c.data, values = Enum((value,))) + return dataclasses.replace(c, data = data) + raise KeyError(f"Key {key} not found in children of {self.key}") + + + def print(self, depth = None): + print("".join(cc for c in self.children for cc in node_tree_to_string(node=c, depth = depth))) + + def transform(self, func: 'Callable[[Tree], Tree | list[Tree]]') -> 'Tree': + """ + Call a function on every node of the tree, return one or more nodes. + If multiple nodes are returned they each get a copy of the (transformed) children of the original node. + Any changes to the children of a node will be ignored. + """ + def transform(node: Tree) -> list[Tree]: + children = [cc for c in node.children for cc in transform(c)] + new_nodes = func(node) + if isinstance(new_nodes, Tree): + new_nodes = [new_nodes] + + return [dataclasses.replace(new_node, children = children) + for new_node in new_nodes] + + children = tuple(cc for c in self.children for cc in transform(c)) + return dataclasses.replace(self, children = children) + + def guess_datatypes(self) -> 'Tree': + def guess_datatypes(node: Tree) -> list[Tree]: + # Try to convert enum values into more structured types + children = tuple(cc for c in node.children for cc in guess_datatypes(c)) + + if isinstance(node.values, Enum): + match node.key: + case "time": range_class = TimeRange + case "date": range_class = DateRange + case _: range_class = None + + if range_class is not None: + return [ + dataclasses.replace(node, values = range, children = children) + for range in range_class.from_strings(node.values.values) + ] + return [dataclasses.replace(node, children = children)] + + children = tuple(cc for c in self.children for cc in guess_datatypes(c)) + return dataclasses.replace(self, children = children) + + + def select(self, selection : dict[str, str | list[str]], mode: Literal["strict", "relaxed"] = "relaxed") -> 'Tree': + # make all values lists + selection = {k : v if isinstance(v, list) else [v] for k,v in selection.items()} + + def not_none(xs): return tuple(x for x in xs if x is not None) + + def select(node: Tree) -> Tree | None: + # Check if the key is specified in the selection + if node.key not in selection: + if mode == "strict": + return None + return dataclasses.replace(node, children = not_none(select(c) for c in node.children)) + + # If the key is specified, check if any of the values match + values = Enum(tuple(c for c in selection[node.key] if c in node.values)) + + if not values: + return None + + return dataclasses.replace(node, values = values, children = not_none(select(c) for c in node.children)) + + return dataclasses.replace(self, children = not_none(select(c) for c in self.children)) + + + @staticmethod + def _insert(position: "Tree", identifier : list[tuple[str, list[str]]]): + """ + This algorithm goes as follows: + We're at a particular node in the tree, and we have a list of key-values pairs that we want to insert. + We take the first key values pair + key, values = identifier.pop(0) + + The general idea is to insert key, values into the current node and use recursion to handle the rest of the identifier. + + We have two sources of values with possible overlap. The values to insert and the values attached to the children of this node. + For each value coming from either source we put it in one of three categories: + 1) Values that exist only in the already existing child. (Coming exclusively from position.children) + 2) Values that exist in both a child and the new values. + 3) Values that exist only in the new values. + + + Thus we add the values to insert to a set, and loop over the children. + For each child we partition its values into the three categories. + + For 1) we create a new child node with the key, reduced set of values and the same children. + For 2) + Create a new child node with the key, and the values in group 2 + Recurse to compute the children + + Once we have finished looping over children we know all the values left over came exclusively from the new values. + So we: + Create a new node with these values. + Recurse to compute the children + + Finally we return the node with all these new children. + """ + if not identifier: + return position + + key, values = identifier.pop(0) + # print(f"Inserting {key}={values} into {position.summary()}") + + # Determine which children have this key + possible_children = {c : [] for c in position.children if c.key == key} + entirely_new_values = [] + + # For each value check it is already in one of the children + for v in values: + for c in possible_children: + if v in c.values: + possible_children[c].append(v) + break + else: # only executed if the loop did not break + # If none of the children have this value, add it to the new child pile + entirely_new_values.append(v) + + # d = {p.summary() : v for p, v in possible_children.items()} + # print(f" {d} new_values={entirely_new_values}") + + new_children = [] + for c, affected in possible_children.items(): + if not affected: + new_children.append(c) + continue + + unaffected = [x for x in c.values if x not in affected] + if unaffected: + unaffected_node = Tree.make(c.key, Enum(tuple(unaffected)), c.children) + new_children.append(unaffected_node) # Add the unaffected part of this child + + if affected: # This check is not technically necessary, but it makes the code more readable + new_node = Tree.make(key, Enum(tuple(affected)), []) + new_node = Tree._insert(new_node, identifier) + new_children.append(new_node) # Add the affected part of this child + + # If there are any values not in any of the existing children, add them as a new child + if entirely_new_values: + new_node = Tree.make(key, Enum(tuple(entirely_new_values)), []) + new_children.append(Tree._insert(new_node, identifier)) + + return Tree.make(position.key, position.values, new_children) + + def insert(self, identifier : dict[str, list[str]]) -> 'Tree': + insertion = [(k, v) for k, v in identifier.items()] + return Tree._insert(self, insertion) + + def to_list_of_cubes(self): + def to_list_of_cubes(node: Tree) -> list[list[Tree]]: + return [[node] + sub_cube for c in node.children for sub_cube in to_list_of_cubes(c)] + + return to_list_of_cubes(self) + + def info(self): + cubes = self.to_list_of_cubes() + print(f"Number of distinct paths: {len(cubes)}") \ No newline at end of file diff --git a/tree_compresser/python_src/tree_traverser/tree_formatters.py b/tree_compresser/python_src/tree_traverser/tree_formatters.py new file mode 100644 index 0000000..eb3f978 --- /dev/null +++ b/tree_compresser/python_src/tree_traverser/tree_formatters.py @@ -0,0 +1,116 @@ +from dataclasses import dataclass +from typing import Iterable, Protocol, Sequence, runtime_checkable + + +@runtime_checkable +class TreeLike(Protocol): + @property + def children(self) -> Sequence["TreeLike"]: ... # Supports indexing like node.children[i] + + def summary(self, **kwargs) -> str: ... + +@dataclass(frozen=True) +class HTML(): + html: str + def _repr_html_(self): + return self.html + +def summarize_node(node: TreeLike, collapse = False, **kwargs) -> tuple[str, TreeLike]: + """ + Extracts a summarized representation of the node while collapsing single-child paths. + Returns the summary string and the last node in the chain that has multiple children. + """ + summaries = [] + + while True: + summary = node.summary(**kwargs) + if len(summary) > 50: + summary = summary[:50] + "..." + summaries.append(summary) + if not collapse: + break + + # Move down if there's exactly one child, otherwise stop + if len(node.children) != 1: + break + node = node.children[0] + + return ", ".join(summaries), node + +def node_tree_to_string(node : TreeLike, prefix : str = "", depth = None) -> Iterable[str]: + summary, node = summarize_node(node) + + if depth is not None and depth <= 0: + yield summary + " - ...\n" + return + # Special case for nodes with only a single child, this makes the printed representation more compact + elif len(node.children) == 1: + yield summary + ", " + yield from node_tree_to_string(node.children[0], prefix, depth = depth) + return + else: + yield summary + "\n" + + for index, child in enumerate(node.children): + connector = "└── " if index == len(node.children) - 1 else "├── " + yield prefix + connector + extension = " " if index == len(node.children) - 1 else "│ " + yield from node_tree_to_string(child, prefix + extension, depth = depth - 1 if depth is not None else None) + +def _node_tree_to_html(node : TreeLike, prefix : str = "", depth = 1, connector = "", **kwargs) -> Iterable[str]: + summary, node = summarize_node(node, **kwargs) + + if len(node.children) == 0: + yield f'{connector}{summary}' + return + else: + open = "open" if depth > 0 else "" + yield f"
{connector}{summary}" + + for index, child in enumerate(node.children): + connector = "└── " if index == len(node.children) - 1 else "├── " + extension = " " if index == len(node.children) - 1 else "│ " + yield from _node_tree_to_html(child, prefix + extension, depth = depth - 1, connector = prefix+connector, **kwargs) + yield "
" + +def node_tree_to_html(node : TreeLike, depth = 1, **kwargs) -> str: + css = """ + + + """ + nodes = "".join(_node_tree_to_html(node=node, depth=depth, **kwargs)) + return f"{css}
{nodes}
" \ No newline at end of file diff --git a/tree_compresser/python_src/tree_traverser/trie.py b/tree_compresser/python_src/tree_traverser/trie.py new file mode 100644 index 0000000..5610eda --- /dev/null +++ b/tree_compresser/python_src/tree_traverser/trie.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass, field + +character = str + +@dataclass(unsafe_hash=True) +class TrieNode(): + parent: "TrieNode | None" + parent_char: character + children: dict[character, "TrieNode"] = field(default_factory=dict) + + +@dataclass +class Trie: + root: TrieNode = field(default_factory=lambda: TrieNode(None, "")) + reverse_lookup: dict[int, TrieNode] = field(default_factory=dict) + + def insert(self, word: str): + node = self.root + for char in word: + if char not in node.children: + new_node = TrieNode(node, char) + node.children[char] = new_node + + node = node.children[char] + + n_id = id(node) + if n_id not in self.reverse_lookup: + self.reverse_lookup[n_id] = node + + return n_id + + def lookup_by_id(self, n_id: int): + leaf_node = self.reverse_lookup[n_id] + string = [] + while leaf_node.parent is not None: + string.append(leaf_node.parent_char) + leaf_node = leaf_node.parent + + return "".join(reversed(string)) + diff --git a/tree_compresser/python_src/tree_traverser/value_types.py b/tree_compresser/python_src/tree_traverser/value_types.py new file mode 100644 index 0000000..bae29f7 --- /dev/null +++ b/tree_compresser/python_src/tree_traverser/value_types.py @@ -0,0 +1,214 @@ +import dataclasses +from abc import ABC, abstractmethod +from dataclasses import dataclass +from datetime import date, datetime, timedelta +from typing import Any, Iterable, Literal + + +@dataclass(frozen=True) +class Values(ABC): + @abstractmethod + def summary(self) -> str: + pass + @abstractmethod + def __len__(self) -> int: + pass + + @abstractmethod + def __contains__(self, value: Any) -> bool: + pass + + @abstractmethod + def from_strings(self, values: Iterable[str]) -> list['Values']: + pass + +@dataclass(frozen=True, order=True) +class Enum(Values): + """ + The simplest kind of key value is just a list of strings. + summary -> string1/string2/string.... + """ + values: tuple[Any, ...] + + def __post_init__(self): + assert isinstance(self.values, tuple) + + def __iter__(self): + return iter(self.values) + + def __len__(self) -> int: + return len(self.values) + def summary(self) -> str: + return '/'.join(map(str, sorted(self.values))) + def __contains__(self, value: Any) -> bool: + return value in self.values + def from_strings(self, values: Iterable[str]) -> list['Values']: + return [Enum(tuple(values))] + +@dataclass(frozen=True) +class Range(Values, ABC): + dtype: str = dataclasses.field(kw_only=True) + +@dataclass(frozen=True) +class DateRange(Range): + start: date + end: date + step: timedelta + dtype: Literal["date"] = dataclasses.field(kw_only=True, default="date") + + @classmethod + def from_strings(self, values: Iterable[str]) -> list['DateRange']: + dates = sorted([datetime.strptime(v, "%Y%m%d") for v in values]) + if len(dates) < 2: + return [DateRange( + start=dates[0], + end=dates[0], + step=timedelta(days=0) + )] + + ranges = [] + current_range, dates = [dates[0],], dates[1:] + while len(dates) > 1: + if dates[0] - current_range[-1] == timedelta(days=1): + current_range.append(dates.pop(0)) + + elif len(current_range) == 1: + ranges.append(DateRange( + start=current_range[0], + end=current_range[0], + step=timedelta(days=0) + )) + current_range = [dates.pop(0),] + + else: + ranges.append(DateRange( + start=current_range[0], + end=current_range[-1], + step=timedelta(days=1) + )) + current_range = [dates.pop(0),] + return ranges + + def __contains__(self, value: Any) -> bool: + v = datetime.strptime(value, "%Y%m%d").date() + return self.start <= v <= self.end and (v - self.start) % self.step == 0 + + + def __len__(self) -> int: + return (self.end - self.start) // self.step + + def summary(self) -> str: + def fmt(d): return d.strftime("%Y%m%d") + if self.step == timedelta(days=0): + return f"{fmt(self.start)}" + if self.step == timedelta(days=1): + return f"{fmt(self.start)}/to/{fmt(self.end)}" + + return f"{fmt(self.start)}/to/{fmt(self.end)}/by/{self.step // timedelta(days=1)}" + +@dataclass(frozen=True) +class TimeRange(Range): + start: int + end: int + step: int + dtype: Literal["time"] = dataclasses.field(kw_only=True, default="time") + + @classmethod + def from_strings(self, values: Iterable[str]) -> list['TimeRange']: + if len(values) == 0: return [] + + times = sorted([int(v) for v in values]) + if len(times) < 2: + return [TimeRange( + start=times[0], + end=times[0], + step=100 + )] + + ranges = [] + current_range, times = [times[0],], times[1:] + while len(times) > 1: + if times[0] - current_range[-1] == 1: + current_range.append(times.pop(0)) + + elif len(current_range) == 1: + ranges.append(TimeRange( + start=current_range[0], + end=current_range[0], + step=0 + )) + current_range = [times.pop(0),] + + else: + ranges.append(TimeRange( + start=current_range[0], + end=current_range[-1], + step=1 + )) + current_range = [times.pop(0),] + return ranges + + def __len__(self) -> int: + return (self.end - self.start) // self.step + + def summary(self) -> str: + def fmt(d): return f"{d:04d}" + if self.step == 0: + return f"{fmt(self.start)}" + return f"{fmt(self.start)}/to/{fmt(self.end)}/by/{self.step}" + + def __contains__(self, value: Any) -> bool: + v = int(value) + return self.start <= v <= self.end and (v - self.start) % self.step == 0 + +@dataclass(frozen=True) +class IntRange(Range): + start: int + end: int + step: int + dtype: Literal["int"] = dataclasses.field(kw_only=True, default="int") + + def __len__(self) -> int: + return (self.end - self.start) // self.step + + def summary(self) -> str: + def fmt(d): return d.strftime("%Y%m%d") + return f"{fmt(self.start)}/to/{fmt(self.end)}/by/{self.step}" + + def __contains__(self, value: Any) -> bool: + v = int(value) + return self.start <= v <= self.end and (v - self.start) % self.step == 0 + + @classmethod + def from_strings(self, values: Iterable[str]) -> list['IntRange']: + if len(values) == 0: return [] + ints = sorted([int(v) for v in values]) + if len(ints) < 2: + return [IntRange( + start=ints[0], + end=ints[0], + step=0 + )] + + ranges = [] + current_range, ints = [ints[0],], ints[1:] + while len(ints) > 1: + if ints[0] - current_range[-1] == 1: + current_range.append(ints.pop(0)) + + elif len(current_range) == 1: + ranges.append(IntRange( + start=current_range[0], + end=current_range[0], + step=0 + )) + current_range = [ints.pop(0),] + + else: + ranges.append(IntRange( + start=current_range[0], + end=current_range[-1], + step=1 + )) + current_range = [ints.pop(0),] + return ranges \ No newline at end of file diff --git a/tree_compresser/tests/open_climate_dt.py b/tree_compresser/tests/open_climate_dt.py index b5e9c18..1171112 100644 --- a/tree_compresser/tests/open_climate_dt.py +++ b/tree_compresser/tests/open_climate_dt.py @@ -3,15 +3,15 @@ from pathlib import Path from tree_traverser import CompressedTree -data_path = Path("/home/eouser/qubed/config/climate-dt/compressed_tree.json") +data_path = Path("./config/climate-dt/compressed_tree.json") # Print size of file print(f"climate dt compressed tree: {data_path.stat().st_size // 1e6:.1f} MB") print("Opening json file") compressed_tree = CompressedTree.load(data_path) -print(compressed_tree.reconstruct_compressed_ecmwf_style()) +print(compressed_tree.to_json()) -# print("Outputting compressed tree ecmwf style") -# with open("data/compressed_tree_climate_dt_ecmwf_style.json", "w") as f: -# json.dump(compressed_tree.reconstruct_compressed_ecmwf_style(), f) +print("Outputting compressed tree ecmwf style") +with open("config/climate-dt/new_format.json", "w") as f: + json.dump(compressed_tree.to_json(), f)