From df8ea6c2f9e353e03d396f1c91281252668ec562 Mon Sep 17 00:00:00 2001
From: Tom Hodson
Date: Thu, 21 Nov 2024 13:57:52 +0000
Subject: [PATCH] add tree_compresser

---
 tree_compresser/pyproject.toml                     |  13 +
 .../src/TreeTraverser/CompressedTree.py            | 305 ++++++++++++++
 .../src/TreeTraverser/fdb_schema/__init__.py       |   1 +
 .../fdb_schema/fdb_schema_parser.py                | 375 ++++++++++++++++++
 .../src/TreeTraverser/fdb_schema/fdb_types.py      |  83 ++++
 5 files changed, 777 insertions(+)
 create mode 100644 tree_compresser/pyproject.toml
 create mode 100644 tree_compresser/src/TreeTraverser/CompressedTree.py
 create mode 100644 tree_compresser/src/TreeTraverser/fdb_schema/__init__.py
 create mode 100644 tree_compresser/src/TreeTraverser/fdb_schema/fdb_schema_parser.py
 create mode 100644 tree_compresser/src/TreeTraverser/fdb_schema/fdb_types.py

diff --git a/tree_compresser/pyproject.toml b/tree_compresser/pyproject.toml
new file mode 100644
index 0000000..391b59f
--- /dev/null
+++ b/tree_compresser/pyproject.toml
@@ -0,0 +1,13 @@
+[build-system]
+requires = ["setuptools >= 61.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "TreeTraverser"
+description = "Tools to work with compressed Datacubes and Trees"
+dynamic = ["version"]
+dependencies = [
+    "fastapi",
+    "pe"
+]
+
diff --git a/tree_compresser/src/TreeTraverser/CompressedTree.py b/tree_compresser/src/TreeTraverser/CompressedTree.py
new file mode 100644
index 0000000..5f8c095
--- /dev/null
+++ b/tree_compresser/src/TreeTraverser/CompressedTree.py
@@ -0,0 +1,305 @@
+import json
+from collections import defaultdict
+from typing import TypeVar
+from pathlib import Path
+
+Tree = dict[str, "Tree"]
+
+class RefcountedDict(dict[str, int]):
+    refcount: int = 1
+
+    def __repr__(self):
+        return f"RefcountedDict(refcount={self.refcount}, {super().__repr__()})"
+
+    def __hash__(self):
+        return hash(tuple(sorted(self.items())))
+
+
+class CompressedTree():
+    """
+    An implementation of a compressed tree that supports lookup, insertion, deletion and caching.
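+
+    A small usage example (in the spirit of the test at the bottom of this file):
+
+        tree = CompressedTree({"a": {"c": {}}, "b": {"c": {}}})
+        tree.lookup(("a", "c"))   # -> (True, ("a", "c"))
+        tree.lookup(("a", "d"))   # -> (False, ("a",))
+        tree.insert(("b", "d"))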
+
+    The caching means that identical subtrees are stored only once, saving memory.
+    This is implemented internally by storing all subtrees in a global hash table.
+
+    """
+    cache: dict[int, RefcountedDict]
+    tree: RefcountedDict
+
+    def _add_to_cache(self, level : RefcountedDict) -> int:
+        "Add a level {key -> hash} to the cache"
+        h = hash(level)
+        if h not in self.cache:
+            # Increase refcounts of the child nodes
+            for child_h in level.values():
+                self.cache[child_h].refcount += 1
+            self.cache[h] = RefcountedDict(level)
+        else:
+            self.cache[h].refcount += 1
+        return h
+
+    def _replace_in_cache(self, old_h, level : RefcountedDict) -> int:
+        """
+        Replace the object at old_h with the object level.
+        If the two objects are identical this is a no-op.
+        """
+        # Start by adding the new object to the cache
+        new_h = self._add_to_cache(level)
+
+        # Now check if the old object needs to be garbage collected
+        self._decrease_refcount(old_h)
+
+        return new_h
+
+    def _decrease_refcount(self, h : int):
+        self.cache[h].refcount -= 1
+        if self.cache[h].refcount == 0:
+            # Recursively decrease refcounts of child nodes
+            for child_h in self.cache[h].values():
+                self._decrease_refcount(child_h)
+            del self.cache[h]
+
+    def cache_tree(self, tree : Tree) -> int:
+        "Insert the given tree (a dictionary of dictionaries) and all of its children, recursively, into the hash table and return the hash key"
+        level = RefcountedDict({k : self.cache_tree(v) for k, v in tree.items()})
+        return self._add_to_cache(level)
+
+
+    def _cache_path(self, path : list[str]) -> int:
+        "Treat path = [x, y, z...] like {x : {y : {z : ...}}} and cache that"
+        if not path:
+            return self.empty_hash
+        k, *rest = path
+        return self._add_to_cache(RefcountedDict({k : self._cache_path(rest)}))
+
+    def reconstruct(self) -> dict[str, dict]:
+        "Reconstruct the tree as a normal nested dictionary"
+        def reconstruct_node(h : int) -> dict[str, dict]:
+            return {k : reconstruct_node(v) for k, v in self.cache[h].items()}
+        return reconstruct_node(self.root_hash)
+
+    def reconstruct_compressed(self) -> dict[str, dict]:
+        "Reconstruct the tree as a nested dictionary, merging keys that share an identical subtree into 'k1/k2' style keys"
+        def reconstruct_node(h : int) -> dict[str, dict]:
+            dedup : dict[int, set[str]] = defaultdict(set)
+            for k, h2 in self.cache[h].items():
+                dedup[h2].add(k)
+
+            return {"/".join(keys) : reconstruct_node(h) for h, keys in dedup.items()}
+        return reconstruct_node(self.root_hash)
+
+    def reconstruct_compressed_ecmwf_style(self) -> dict[str, dict]:
+        "Reconstruct the tree as a nested dictionary, merging values of the same key into 'key=v1,v2' style keys"
+        def reconstruct_node(h : int) -> dict[str, dict]:
+            dedup : dict[tuple[int, str], set[str]] = defaultdict(set)
+            for k, h2 in self.cache[h].items():
+                key, value = k.split("=")
+                dedup[(h2, key)].add(value)
+
+            return {f"{key}={','.join(values)}" : reconstruct_node(h) for (h, key), values in dedup.items()}
+        return reconstruct_node(self.root_hash)
+
+    def __init__(self, tree : Tree):
+        self.cache = {}
+        self.empty_hash = hash(RefcountedDict({}))
+
+        # Recursively cache the tree
+        self.root_hash = self.cache_tree(tree)
+
+        # Keep a reference to the root of the tree
+        self.tree = self.cache[self.root_hash]
+
+
+    def lookup(self, keys : tuple[str, ...]) -> tuple[bool, tuple[str, ...]]:
+        """
+        Lookup a subtree in the tree.
+        Returns (success, path):
+        if success == True the lookup reached the bottom of the tree and path will be equal to keys
+        if success == False, path holds the keys that were found
+        """
+        loc = self.tree
+        for i, key in enumerate(keys):
+            if key in loc:
+                h = loc[key] # get the hash of the subtree
+                loc = self.cache[h] # get the subtree
+            else:
+                return False, keys[:i]
+        return True, keys
+
+    def keys(self, keys : tuple[str, ...] = ()) -> list[str] | None:
+        loc = self.tree
+        for i, key in enumerate(keys):
+            if key in loc:
+                h = loc[key] # get the hash of the subtree
+                loc = self.cache[h] # get the subtree
+            else:
+                return None
+        return list(loc.keys())
+
+    def multi_match(self, request : dict[str, list[str]], loc = None):
+        if loc is None: loc = self.tree
+        if not loc: return {"_END_" : {}}
+        matches = {}
+        for request_key, request_values in request.items():
+            for request_value in request_values:
+                meta_key = f"{request_key}={request_value}"
+                if meta_key in loc:
+                    new_loc = self.cache[loc[meta_key]]
+                    matches[meta_key] = self.multi_match(request, new_loc)
+
+        if not matches: return {k : {} for k in loc.keys()}
+        return matches
+
+
+    def _insert(self, old_h : int, tree: RefcountedDict, keys : tuple[str, ...]) -> int:
+        "Insert keys in the subtree and return the new hash of the subtree"
+        key, *rest = keys
+        assert old_h in self.cache
+
+        # Adding a new branch to the tree
+        if key not in tree:
+            new_tree = RefcountedDict(tree | {key : self._cache_path(rest)})
+
+        else:
+            # Make a copy of the tree and update the subtree
+            new_tree = RefcountedDict(tree.copy())
+            subtree_h = tree[key]
+            subtree = self.cache[subtree_h]
+            new_tree[key] = self._insert(subtree_h, subtree, tuple(rest))
+
+        # no-op if the hash hasn't changed
+        new_h = self._replace_in_cache(old_h, new_tree)
+        return new_h
+
+
+    def insert(self, keys : tuple[str, ...]):
+        """
+        Insert a new branch into the compressed tree.
+        """
+        already_there, path = self.lookup(keys)
+        if already_there:
+            return
+        # Update the tree
+        self.root_hash = self._insert(self.root_hash, self.tree, keys)
+        self.tree = self.cache[self.root_hash]
+
+    def insert_tree(self, subtree: Tree):
+        """
+        Insert a whole tree into the compressed tree.
+        """
+        self.root_hash = self._insert_tree(self.root_hash, self.tree, subtree)
+        self.tree = self.cache[self.root_hash]
+
+    def _insert_tree(self, old_h: int, tree: RefcountedDict, subtree: Tree) -> int:
+        """
+        Recursively insert a subtree into the compressed tree and return the new hash.
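+        Existing nodes are copied rather than mutated, so any other part of the tree that
+        shares the same subtree is unaffected; reference counts are updated via _replace_in_cache.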
+ """ + assert old_h in self.cache + + # Make a copy of the tree to avoid modifying shared structures + new_tree = RefcountedDict(tree.copy()) + for key, sub_subtree in subtree.items(): + if key not in tree: + # Key is not in current tree, add the subtree + # Cache the subtree rooted at sub_subtree + subtree_h = self.cache_tree(sub_subtree) + new_tree[key] = subtree_h + else: + # Key is in tree, need to recursively merge + # Get the hash and subtree from the current tree + child_h = tree[key] + child_tree = self.cache[child_h] + # Recursively merge + new_child_h = self._insert_tree(child_h, child_tree, sub_subtree) + new_tree[key] = new_child_h + + # Replace the old hash with the new one in the cache + new_h = self._replace_in_cache(old_h, new_tree) + return new_h + + def save(self, path : Path): + "Save the compressed tree to a file" + with open(path, "w") as f: + json.dump({ + "cache" : {k : {"refcount" : v.refcount, "dict" : v} for k, v in self.cache.items()}, + "root_hash": self.root_hash + }, f) + + @classmethod + def load(cls, path : Path) -> "CompressedTree": + "Load the compressed tree from a file" + with open(path) as f: + data = json.load(f) + return cls.from_json(data) + + + @classmethod + def from_json(cls, data : dict) -> "CompressedTree": + c = CompressedTree({}) + c.cache = {} + for k, v in data["cache"].items(): + c.cache[int(k)] = RefcountedDict(v["dict"]) + c.cache[int(k)].refcount = v["refcount"] + + c.root_hash = data["root_hash"] + c.tree = c.cache[c.root_hash] + return c + + +if __name__ == "__main__": + original_tree = { + "a": { + "b1": { + "c": {} + }, + "b2" : { + "c": {} + }, + "b3*": { + "c*": {} + } + } + } + + c_tree = CompressedTree(original_tree) + + assert c_tree.lookup(("a", "b1", "c")) == (True, ("a", "b1", "c")) + assert c_tree.lookup(("a", "b1", "d")) == (False, ("a", "b1")) + + print(json.dumps(c_tree.reconstruct_compressed(), indent = 4)) + + assert c_tree.reconstruct() == original_tree + + c_tree.insert(("a", "b1", "d")) + c_tree.insert(("a", "b2", "d")) + print(json.dumps(c_tree.reconstruct(), indent = 4)) + + print(json.dumps(c_tree.reconstruct_compressed(), indent = 4)) + print(c_tree.cache) + + # test round trip + assert CompressedTree(original_tree).reconstruct() == original_tree + + # test adding a key + added_keys_tree = { + "a": { + "b1": { + "c": {} + }, + "b2" : { + "c": {}, + "d" : {} + }, + "b3*": { + "c*": {}, + "d*": {} + } + } + } + c_tree = CompressedTree(original_tree) + c_tree.insert(("a", "b2", "d")) + c_tree.insert(("a", "b3*", "d*")) + assert c_tree.reconstruct() == added_keys_tree + + print(c_tree.reconstruct_compressed()) \ No newline at end of file diff --git a/tree_compresser/src/TreeTraverser/fdb_schema/__init__.py b/tree_compresser/src/TreeTraverser/fdb_schema/__init__.py new file mode 100644 index 0000000..56160a5 --- /dev/null +++ b/tree_compresser/src/TreeTraverser/fdb_schema/__init__.py @@ -0,0 +1 @@ +from .fdb_schema_parser import FDBSchema, FDBSchemaFile, KeySpec, Key diff --git a/tree_compresser/src/TreeTraverser/fdb_schema/fdb_schema_parser.py b/tree_compresser/src/TreeTraverser/fdb_schema/fdb_schema_parser.py new file mode 100644 index 0000000..852c1bb --- /dev/null +++ b/tree_compresser/src/TreeTraverser/fdb_schema/fdb_schema_parser.py @@ -0,0 +1,375 @@ +import dataclasses +import json +from dataclasses import dataclass, field +from typing import Any + +import pe +from pe.actions import Pack +from pe.operators import Class, Star + +from .fdb_types import FDB_type_to_implementation, FDBType + + 
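+# The classes below model a parsed FDB schema file: KeySpec nodes form the schema
+# tree, and FDBSchema (defined further down) is the entry point that parses the
+# schema text and matches request dictionaries against it.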
+@dataclass(frozen=True) +class KeySpec: + """ + Represents the specification of a single key in an FDB schema file. For example in + ``` + [ class, expver, stream=lwda, date, time, domain? + [ type=ofb/mfb/oai + [ obsgroup, reportype ]]] + ``` + class, expver, type=ofdb/mfb/oai etc are the KeySpecs + + These can have additional information such as: flags like `domain?`, allowed values like `type=ofb/mfb/oai` + or specify type information with `date: ClimateMonthly` + + """ + + key: str + type: FDBType = field(default_factory=FDBType) + flag: str | None = None + values: tuple = field(default_factory=tuple) + comment: str = "" + + def __repr__(self): + repr = self.key + if self.flag: + repr += self.flag + # if self.type: + # repr += f":{self.type}" + if self.values: + repr += "=" + "/".join(self.values) + return repr + + def matches(self, key, value): + # Sanity check! + if self.key != key: + return False + + # Some keys have a set of allowed values type=ofb/mfb/oai + if self.values: + if value not in self.values: + return False + + # Check the formatting of values like Time or Date + if self.type and not self.type.validate(value): + return False + + return True + + def is_optional(self): + if self.flag is None: + return False + return "?" in self.flag + + def is_allable(self): + if self.flag is None: + return False + return "*" in self.flag + + +@dataclass(frozen=True) +class Comment: + "Represents a comment node in the schema" + + value: str + + +@dataclass(frozen=True) +class FDBSchemaTypeDef: + "Mapping between FDB schema key names and FDB Schema Types, i.e expver is of type Expver" + + key: str + type: str + + +# This is the schema grammar written in PEG format +fdb_schema = pe.compile( + r""" + FDB < Line+ EOF + Line < Schema / Comment / TypeDef / empty + + # Comments + Comment <- "#" ~non_eol* + non_eol <- [\x09\x20-\x7F] / non_ascii + non_ascii <- [\x80-\uD7FF\uE000-\U0010FFFF] + + # Default Type Definitions + TypeDef < String ":" String ";" + + # Schemas are the main attraction + # They're a tree of KeySpecs. + Schema < "[" KeySpecs (","? Schema)* "]" + + # KeySpecs can be just a name i.e expver + # Can also have a type expver:int + # Or a flag expver? + # Or values expver=xxx + KeySpecs < KeySpec_ws ("," KeySpec_ws)* + KeySpec_ws < KeySpec + KeySpec <- key:String (flag:Flag)? (type:Type)? (values:Values)? ([ ]* comment:Comment)? + Flag <- ~("?" / "-" / "*") + Type <- ":" [ ]* String + Values <- "=" Value ("/" Value)* + + # Low level stuff + Value <- ~([-a-zA-Z0-9_]+) + String <- ~([a-zA-Z0-9_]+) + EOF <- !. + empty <- "" + """, + actions={ + "Schema": Pack(tuple), + "KeySpec": KeySpec, + "Values": Pack(tuple), + "Comment": Comment, + "TypeDef": FDBSchemaTypeDef, + }, + ignore=Star(Class("\t\f\r\n ")), + # flags=pe.DEBUG, +) + + +def post_process(entries): + "Take the raw output from the PEG parser and split it into type definitions and schema entries." + typedefs = {} + schemas = [] + for entry in entries: + match entry: + case c if isinstance(c, Comment): + pass + case t if isinstance(t, FDBSchemaTypeDef): + typedefs[t.key] = t.type + case s if isinstance(s, tuple): + schemas.append(s) + case _: + raise ValueError + return typedefs, tuple(schemas) + + +def determine_types(types, node): + "Recursively walk a schema tree and insert the type information." 
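+    # A node is either a tuple of child nodes (a schema or sub-schema) or a single
+    # KeySpec leaf; dataclasses.replace is used because KeySpec is a frozen dataclass.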
+ if isinstance(node, tuple): + return [determine_types(types, n) for n in node] + return dataclasses.replace(node, type=types.get(node.key, FDBType())) + + +@dataclass +class Key: + key: str + value: Any + key_spec: KeySpec + reason: str + + def str_value(self): + return self.key_spec.type.format(self.value) + + def __bool__(self): + return self.reason in {"Matches", "Skipped", "Select All"} + + def emoji(self): + return {"Matches": "✅", "Skipped": "⏭️", "Select All": "★"}.get( + self.reason, "❌" + ) + + def info(self): + return f"{self.emoji()} {self.key:<12}= {str(self.value):<12} ({self.key_spec}) {self.reason if not self else ''}" + + def __repr__(self): + return f"{self.key}={self.key_spec.type.format(self.value)}" + + def as_json(self): + return dict( + key=self.key, + value=self.str_value(), + reason=self.reason, + ) + + +class FDBSchema: + """ + Represents a parsed FDB Schema file. + Has methods to validate and convert request dictionaries to a mars request form with validation and type information. + """ + + def __init__(self, string, defaults: dict[str, str] = {}): + """ + 1. Use a PEG parser on a schema string, + 2. Separate the output into schemas and typedefs + 3. Insert any concrete implementations of types from fdb_types.py defaulting to generic string type + 4. Walk the schema tree and annotate it with type information. + """ + m = fdb_schema.match(string) + g = list(m.groups()) + self._str_types, schemas = post_process(g) + self.types = { + key: FDB_type_to_implementation[type] + for key, type in self._str_types.items() + } + self.schemas = determine_types(self.types, schemas) + self.defaults = defaults + + def __repr__(self): + return json.dumps( + dict(schemas=self.schemas, defaults=self.defaults), indent=4, default=repr + ) + + @classmethod + def consume_key( + cls, key_spec: KeySpec, request: dict[str, Any] + ) -> Key: + key = key_spec.key + try: + value = request[key] + except KeyError: + if key_spec.is_optional(): + return Key(key_spec.key, "", key_spec, "Skipped") + if key_spec.is_allable(): + return Key(key_spec.key, "", key_spec, "Select All") + else: + return Key( + key_spec.key, "", key_spec, "Key Missing" + ) + + if key_spec.matches(key, value): + return Key( + key_spec.key, + key_spec.type.parse(value), + key_spec, + "Matches", + ) + else: + return Key( + key_spec.key, value, key_spec, "Incorrect Value" + ) + + @classmethod + def _DFS_match( + cls, tree: list, request: dict[str, Any] + ) -> tuple[bool | list, list[Key]]: + """Do a DFS on the schema tree, returning the deepest matching path + At each stage return whether we matched on this path, and the path itself. + + When traversing the tree there are three cases to consider: + 1. base case [] + 2. one schema [k, k, k, [k, k, k]] + 3. list of schemas [[k,k,k], [k,k,k], [k,k,k]] + """ + # Case 1: Base Case + if not tree: + return True, [] + + # Case 2: [k, k, k, [k, k, k]] + if isinstance(tree[0], KeySpec): + node, *tree = tree + # Check if this node is in the request + match_result = cls.consume_key(node, request) + + # If if isn't then terminate this path here + if not match_result: + return False, [match_result,] # fmt: skip + + # Otherwise continue walking the tree and return the best result + matched, path = cls._DFS_match(tree, request) + + # Don't put the key in the path if it's optional and we're skipping it. 
+ if match_result.reason != "Skipped": + path = [match_result,] + path # fmt: skip + + return matched, path + + # Case 3: [[k, k, k], [k, k, k]] + branches = [] + for branch in tree: + matched, branch_path = cls._DFS_match(branch, request) + + # If this branch matches, terminate the DFS and use this. + if matched: + return branch, branch_path + else: + branches.append(branch_path) + + # If no branch matches, return the one with the deepest match + return False, max(branches, key=len) + + @classmethod + def _DFS_match_all( + cls, tree: list, request: dict[str, Any] + ) -> list[list[Key]]: + """Do a DFS on the schema tree, returning all matching paths or partial matches. + At each stage return all matching paths and the deepest partial matches. + + When traversing the tree there are three cases to consider: + 1. base case [] + 2. one schema [k, k, k, [k, k, k]] + 3. list of schemas [[k,k,k], [k,k,k], [k,k,k]] + """ + # Case 1: Base Case + if not tree: + return [[]] + + # Case 2: [k, k, k, [k, k, k]] + if isinstance(tree[0], KeySpec): + node, *tree = tree + # Check if this node is in the request + request_values = request.get(node.key, None) + + if request_values is None: + # If the key is not in the request, return a partial match with Key Missing + return [[Key(node.key, "", node, "Key Missing")]] + + # If the request value is a list, try to match each value + if isinstance(request_values, list): + all_matches = [] + for value in request_values: + match_result = cls.consume_key(node, {node.key: value}) + + if match_result: + sub_matches = cls._DFS_match_all(tree, request) + for match in sub_matches: + if match_result.reason != "Skipped": + match.insert(0, match_result) + all_matches.append(match) + + return all_matches if all_matches else [[Key(node.key, "", node, "No Match Found")]] + else: + # Handle a single value + match_result = cls.consume_key(node, request) + + # If it isn't then return a partial match with Key Missing + if not match_result: + return [[Key(node.key, "", node, "Key Missing")]] + + # Continue walking the tree and get all matches + all_matches = cls._DFS_match_all(tree, request) + + # Prepend the current match to all further matches + for match in all_matches: + if match_result.reason != "Skipped": + match.insert(0, match_result) + + return all_matches + + # Case 3: [[k, k, k], [k, k, k]] + all_branch_matches = [] + for branch in tree: + branch_matches = cls._DFS_match_all(branch, request) + all_branch_matches.extend(branch_matches) + + # Return all of the deepest partial matches or complete matches + return all_branch_matches + + def match_all(self, request: dict[str, Any]): + request = request | self.defaults + return self._DFS_match_all(self.schemas, request) + + def match(self, request: dict[str, Any]): + request = request | self.defaults + return self._DFS_match(self.schemas, request) + + +class FDBSchemaFile(FDBSchema): + def __init__(self, path: str): + with open(path, "r") as f: + return super().__init__(f.read()) diff --git a/tree_compresser/src/TreeTraverser/fdb_schema/fdb_types.py b/tree_compresser/src/TreeTraverser/fdb_schema/fdb_types.py new file mode 100644 index 0000000..05093db --- /dev/null +++ b/tree_compresser/src/TreeTraverser/fdb_schema/fdb_types.py @@ -0,0 +1,83 @@ +from dataclasses import dataclass +from typing import Any +import re +from collections import defaultdict +from datetime import datetime, date, time + + +@dataclass(repr=False) +class FDBType: + """ + Holds information about how to format and validate a given FDB Schema type like Time 
or Expver.
+    This base type represents a string and does no validation or formatting. It's the default type.
+    """
+
+    name: str = "String"
+
+    def __repr__(self) -> str:
+        return self.name
+
+    def validate(self, s: Any) -> bool:
+        try:
+            self.parse(s)
+            return True
+        except (ValueError, AssertionError):
+            return False
+
+    def format(self, s: Any) -> str:
+        return str(s).lower()
+
+    def parse(self, s: str) -> Any:
+        return s
+
+
+@dataclass(repr=False)
+class Expver_FDBType(FDBType):
+    name: str = "Expver"
+
+    def parse(self, s: str) -> str:
+        # An expver is exactly four characters, e.g. "0001"
+        assert bool(re.fullmatch(r".{4}", s))
+        return s
+
+
+@dataclass(repr=False)
+class Time_FDBType(FDBType):
+    name: str = "Time"
+    time_format = "%H%M"
+
+    def format(self, t: time) -> str:
+        return t.strftime(self.time_format)
+
+    def parse(self, s: datetime | str | int) -> time:
+        if isinstance(s, str):
+            assert len(s) == 4
+            return datetime.strptime(s, self.time_format).time()
+        if isinstance(s, datetime):
+            return s.time()
+        return self.parse(f"{s:04}")
+
+
+@dataclass(repr=False)
+class Date_FDBType(FDBType):
+    name: str = "Date"
+    date_format: str = "%Y%m%d"
+
+    def format(self, d: Any) -> str:
+        if isinstance(d, date):
+            return d.strftime(self.date_format)
+        if isinstance(d, int):
+            return f"{d:08}"
+        else:
+            return d
+
+    def parse(self, s: datetime | str | int) -> date:
+        if isinstance(s, str):
+            return datetime.strptime(s, self.date_format).date()
+        elif isinstance(s, datetime):
+            return s.date()
+        return self.parse(f"{s:08}")
+
+
+FDB_type_to_implementation = defaultdict(lambda: FDBType()) | {
+    cls.name: cls() for cls in [Expver_FDBType, Time_FDBType, Date_FDBType]
+}
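
A quick usage sketch of the schema matcher added above (assuming the package from this
patch is installed). The schema text reuses the example from the KeySpec docstring; the
request values are invented for illustration and are not taken from the patch:

    from TreeTraverser.fdb_schema import FDBSchema

    schema = FDBSchema(
        """
        expver: Expver;
        date: Date;
        time: Time;

        [ class, expver, stream=lwda, date, time, domain?
               [ type=ofb/mfb/oai
                      [ obsgroup, reportype ]]]
        """
    )

    request = {
        "class": "ea", "expver": "0001", "stream": "lwda",
        "date": "20241121", "time": "1200",
        "type": "ofb", "obsgroup": "conv", "reportype": "16001",
    }

    matched, path = schema.match(request)
    assert matched
    for key in path:
        print(key.info())   # one line per matched key, e.g. "✅ expver      = 0001 ..."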