add tree_compresser

This commit is contained in:
Tom Hodson 2024-11-21 13:57:52 +00:00
parent 50d86c77ec
commit df8ea6c2f9
5 changed files with 777 additions and 0 deletions

View File

@@ -0,0 +1,13 @@
[build-system]
requires = ["setuptools >= 61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "TreeTraverser"
description = "Tools to work with compressed Datacubes and Trees"
dynamic = ["version"]
dependencies = [
"fastapi",
"pe"
]

View File

@@ -0,0 +1,305 @@
import json
from collections import defaultdict
from typing import TypeVar
from pathlib import Path
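# A Tree is a nested dictionary mapping keys to subtrees; leaves are empty dicts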
Tree = dict[str, "Tree"]
class RefcountedDict(dict[str, int]):
refcount: int = 1
def __repr__(self):
return f"RefcountedDict(refcount={self.refcount}, {super().__repr__()})"
def __hash__(self):
return hash(tuple(sorted(self.items())))
class CompressedTree():
"""
A implementation of a compressed tree that supports lookup, insertion, deletion and caching.
The caching means that identical subtrees are stored only once, saving memory
This is implemented internal by storing all subtrees in a global hash table
"""
cache: dict[int, RefcountedDict]
tree: RefcountedDict
def _add_to_cache(self, level : RefcountedDict) -> int:
"Add a level {key -> hash} to the cache"
h = hash(level)
if h not in self.cache:
# Increase refcounts of the child nodes
for child_h in level.values():
self.cache[child_h].refcount += 1
self.cache[h] = RefcountedDict(level)
else:
self.cache[h].refcount += 1
return h
    def _replace_in_cache(self, old_h : int, level : RefcountedDict) -> int:
        """
        Replace the object at old_h with a different object, level.
        If the two objects are identical this is a no-op.
        """
# Start by adding the new object to the cache
new_h = self._add_to_cache(level)
# Now check if the old object needs to be garbage collected
self._decrease_refcount(old_h)
return new_h
def _decrease_refcount(self, h : int):
self.cache[h].refcount -= 1
if self.cache[h].refcount == 0:
# Recursively decrease refcounts of child nodes
for child_h in self.cache[h].values():
self._decrease_refcount(child_h)
del self.cache[h]
def cache_tree(self, tree : Tree) -> int:
"Insert the given tree (dictonary of dictionaries) (all it's children, recursively) into the hash table and return the hash key"
level = RefcountedDict({k : self.cache_tree(v) for k, v in tree.items()})
return self._add_to_cache(level)
def _cache_path(self, path : list[str]) -> int:
"Treat path = [x, y, z...] like {x : {y : {z : ...}}} and cache that"
if not path:
return self.empty_hash
k, *rest = path
return self._add_to_cache(RefcountedDict({k : self._cache_path(rest)}))
def reconstruct(self) -> dict[str, dict]:
"Reconstruct the tree as a normal nested dictionary"
def reconstruct_node(h : int) -> dict[str, dict]:
return {k : reconstruct_node(v) for k, v in self.cache[h].items()}
return reconstruct_node(self.root_hash)
def reconstruct_compressed(self) -> dict[str, dict]:
"Reconstruct the tree as a normal nested dictionary"
def reconstruct_node(h : int) -> dict[str, dict]:
dedup : dict[int, set[str]] = defaultdict(set)
for k, h2 in self.cache[h].items():
dedup[h2].add(k)
return {"/".join(keys) : reconstruct_node(h) for h, keys in dedup.items()}
return reconstruct_node(self.root_hash)
def reconstruct_compressed_ecmwf_style(self) -> dict[str, dict]:
"Reconstruct the tree as a normal nested dictionary"
def reconstruct_node(h : int) -> dict[str, dict]:
dedup : dict[tuple[int, str], set[str]] = defaultdict(set)
for k, h2 in self.cache[h].items():
key, value = k.split("=")
dedup[(h2, key)].add(value)
return {f"{key}={','.join(values)}" : reconstruct_node(h) for (h, key), values in dedup.items()}
return reconstruct_node(self.root_hash)
def __init__(self, tree : Tree):
self.cache = {}
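        # Hash of the empty level, used as the sentinel for leaf nodes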
self.empty_hash = hash(RefcountedDict({}))
# Recursively cache the tree
self.root_hash = self.cache_tree(tree)
# Keep a reference to the root of the tree
self.tree = self.cache[self.root_hash]
def lookup(self, keys : tuple[str, ...]) -> tuple[bool, tuple[str, ...]]:
"""
Lookup a subtree in the tree
Returns success, path
if success == True it means the path got to the bottom of the tree and path will be equal to keys
if success == False, path will holds the keys that were found
"""
loc = self.tree
for i, key in enumerate(keys):
if key in loc:
h = loc[key] # get the hash of the subtree
loc = self.cache[h] # get the subtree
else:
return False, keys[:i]
return True, keys
    def keys(self, keys : tuple[str, ...] = ()) -> list[str] | None:
        loc = self.tree
        for key in keys:
            if key in loc:
                h = loc[key] # get the hash of the subtree
                loc = self.cache[h] # get the subtree
            else:
                return None
        return list(loc.keys())
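    # multi_match takes a request of the form {key: [value, ...]}; tree levels are
    # expected to use "key=value" style keys, and a nested dict of matches is returned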
    def multi_match(self, request : dict[str, list[str]], loc = None):
        if loc is None: loc = self.tree
        if not loc: return {"_END_" : {}}
matches = {}
for request_key, request_values in request.items():
for request_value in request_values:
meta_key = f"{request_key}={request_value}"
if meta_key in loc:
new_loc = self.cache[loc[meta_key]]
matches[meta_key] = self.multi_match(request, new_loc)
        if not matches: return {k : {} for k in loc}
return matches
def _insert(self, old_h : int, tree: RefcountedDict, keys : tuple[str, ...]) -> int:
"Insert keys in the subtree and return the new hash of the subtree"
key, *rest = keys
assert old_h in self.cache
# Adding a new branch to the tree
if key not in tree:
new_tree = RefcountedDict(tree | {key : self._cache_path(rest)})
else:
# Make a copy of the tree and update the subtree
new_tree = RefcountedDict(tree.copy())
subtree_h = tree[key]
subtree = self.cache[subtree_h]
new_tree[key] = self._insert(subtree_h, subtree, tuple(rest))
# no-op if the hash hasn't changed
new_h = self._replace_in_cache(old_h, new_tree)
return new_h
def insert(self, keys : tuple[str, ...]):
"""
Insert a new branch into the compressed tree
"""
        already_there, _ = self.lookup(keys)
if already_there:
return
# Update the tree
self.root_hash = self._insert(self.root_hash, self.tree, keys)
self.tree = self.cache[self.root_hash]
def insert_tree(self, subtree: Tree):
"""
Insert a whole tree into the compressed tree.
"""
self.root_hash = self._insert_tree(self.root_hash, self.tree, subtree)
self.tree = self.cache[self.root_hash]
def _insert_tree(self, old_h: int, tree: RefcountedDict, subtree: Tree) -> int:
"""
Recursively insert a subtree into the compressed tree and return the new hash.
"""
assert old_h in self.cache
# Make a copy of the tree to avoid modifying shared structures
new_tree = RefcountedDict(tree.copy())
for key, sub_subtree in subtree.items():
if key not in tree:
# Key is not in current tree, add the subtree
# Cache the subtree rooted at sub_subtree
subtree_h = self.cache_tree(sub_subtree)
new_tree[key] = subtree_h
else:
# Key is in tree, need to recursively merge
# Get the hash and subtree from the current tree
child_h = tree[key]
child_tree = self.cache[child_h]
# Recursively merge
new_child_h = self._insert_tree(child_h, child_tree, sub_subtree)
new_tree[key] = new_child_h
# Replace the old hash with the new one in the cache
new_h = self._replace_in_cache(old_h, new_tree)
return new_h
def save(self, path : Path):
"Save the compressed tree to a file"
with open(path, "w") as f:
json.dump({
"cache" : {k : {"refcount" : v.refcount, "dict" : v} for k, v in self.cache.items()},
"root_hash": self.root_hash
}, f)
@classmethod
def load(cls, path : Path) -> "CompressedTree":
"Load the compressed tree from a file"
with open(path) as f:
data = json.load(f)
return cls.from_json(data)
@classmethod
def from_json(cls, data : dict) -> "CompressedTree":
c = CompressedTree({})
c.cache = {}
for k, v in data["cache"].items():
c.cache[int(k)] = RefcountedDict(v["dict"])
c.cache[int(k)].refcount = v["refcount"]
c.root_hash = data["root_hash"]
c.tree = c.cache[c.root_hash]
return c
if __name__ == "__main__":
original_tree = {
"a": {
"b1": {
"c": {}
},
"b2" : {
"c": {}
},
"b3*": {
"c*": {}
}
}
}
c_tree = CompressedTree(original_tree)
assert c_tree.lookup(("a", "b1", "c")) == (True, ("a", "b1", "c"))
assert c_tree.lookup(("a", "b1", "d")) == (False, ("a", "b1"))
print(json.dumps(c_tree.reconstruct_compressed(), indent = 4))
assert c_tree.reconstruct() == original_tree
c_tree.insert(("a", "b1", "d"))
c_tree.insert(("a", "b2", "d"))
print(json.dumps(c_tree.reconstruct(), indent = 4))
print(json.dumps(c_tree.reconstruct_compressed(), indent = 4))
print(c_tree.cache)
# test round trip
assert CompressedTree(original_tree).reconstruct() == original_tree
# test adding a key
added_keys_tree = {
"a": {
"b1": {
"c": {}
},
"b2" : {
"c": {},
"d" : {}
},
"b3*": {
"c*": {},
"d*": {}
}
}
}
c_tree = CompressedTree(original_tree)
c_tree.insert(("a", "b2", "d"))
c_tree.insert(("a", "b3*", "d*"))
assert c_tree.reconstruct() == added_keys_tree
print(c_tree.reconstruct_compressed())
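
    # A minimal sketch of the structural sharing: "b1" and "b2" both contain
    # {"c": {}}, so they point at the same cache entry, while "b3*" does not.
    c_tree = CompressedTree(original_tree)
    level_a = c_tree.cache[c_tree.tree["a"]]
    assert level_a["b1"] == level_a["b2"]
    assert level_a["b1"] != level_a["b3*"]

    # insert_tree merges a whole nested dictionary in one call
    c_tree.insert_tree({"a": {"b4": {"c": {}}}})
    assert c_tree.lookup(("a", "b4", "c")) == (True, ("a", "b4", "c"))

    # save/load round trip through JSON, using a temporary file
    import tempfile
    with tempfile.TemporaryDirectory() as tmpdir:
        p = Path(tmpdir) / "tree.json"
        c_tree.save(p)
        assert CompressedTree.load(p).reconstruct() == c_tree.reconstruct()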

View File

@@ -0,0 +1 @@
from .fdb_schema_parser import FDBSchema, FDBSchemaFile, KeySpec, Key

View File

@@ -0,0 +1,375 @@
import dataclasses
import json
from dataclasses import dataclass, field
from typing import Any
import pe
from pe.actions import Pack
from pe.operators import Class, Star
from .fdb_types import FDB_type_to_implementation, FDBType
@dataclass(frozen=True)
class KeySpec:
"""
Represents the specification of a single key in an FDB schema file. For example in
```
[ class, expver, stream=lwda, date, time, domain?
[ type=ofb/mfb/oai
[ obsgroup, reportype ]]]
```
    class, expver, type=ofb/mfb/oai etc. are the KeySpecs.
    These can carry additional information: flags like `domain?`, allowed values like `type=ofb/mfb/oai`,
    or type information like `date: ClimateMonthly`.
"""
key: str
type: FDBType = field(default_factory=FDBType)
flag: str | None = None
values: tuple = field(default_factory=tuple)
comment: str = ""
    def __repr__(self):
        s = self.key
        if self.flag:
            s += self.flag
        # if self.type:
        #     s += f":{self.type}"
        if self.values:
            s += "=" + "/".join(self.values)
        return s
def matches(self, key, value):
# Sanity check!
if self.key != key:
return False
# Some keys have a set of allowed values type=ofb/mfb/oai
if self.values:
if value not in self.values:
return False
# Check the formatting of values like Time or Date
if self.type and not self.type.validate(value):
return False
return True
def is_optional(self):
if self.flag is None:
return False
return "?" in self.flag
def is_allable(self):
if self.flag is None:
return False
return "*" in self.flag
@dataclass(frozen=True)
class Comment:
"Represents a comment node in the schema"
value: str
@dataclass(frozen=True)
class FDBSchemaTypeDef:
"Mapping between FDB schema key names and FDB Schema Types, i.e expver is of type Expver"
key: str
type: str
# This is the schema grammar written in PEG format
fdb_schema = pe.compile(
r"""
FDB < Line+ EOF
Line < Schema / Comment / TypeDef / empty
# Comments
Comment <- "#" ~non_eol*
non_eol <- [\x09\x20-\x7F] / non_ascii
non_ascii <- [\x80-\uD7FF\uE000-\U0010FFFF]
# Default Type Definitions
TypeDef < String ":" String ";"
# Schemas are the main attraction
# They're a tree of KeySpecs.
Schema < "[" KeySpecs (","? Schema)* "]"
# KeySpecs can be just a name i.e expver
# Can also have a type expver:int
# Or a flag expver?
# Or values expver=xxx
KeySpecs < KeySpec_ws ("," KeySpec_ws)*
KeySpec_ws < KeySpec
KeySpec <- key:String (flag:Flag)? (type:Type)? (values:Values)? ([ ]* comment:Comment)?
Flag <- ~("?" / "-" / "*")
Type <- ":" [ ]* String
Values <- "=" Value ("/" Value)*
# Low level stuff
Value <- ~([-a-zA-Z0-9_]+)
String <- ~([a-zA-Z0-9_]+)
EOF <- !.
empty <- ""
""",
actions={
"Schema": Pack(tuple),
"KeySpec": KeySpec,
"Values": Pack(tuple),
"Comment": Comment,
"TypeDef": FDBSchemaTypeDef,
},
ignore=Star(Class("\t\f\r\n ")),
# flags=pe.DEBUG,
)
def post_process(entries):
"Take the raw output from the PEG parser and split it into type definitions and schema entries."
typedefs = {}
schemas = []
for entry in entries:
        match entry:
            case Comment():
                pass
            case FDBSchemaTypeDef() as t:
                typedefs[t.key] = t.type
            case tuple() as s:
                schemas.append(s)
            case _:
                raise ValueError(f"Unexpected entry in schema: {entry!r}")
return typedefs, tuple(schemas)
def determine_types(types, node):
"Recursively walk a schema tree and insert the type information."
if isinstance(node, tuple):
return [determine_types(types, n) for n in node]
return dataclasses.replace(node, type=types.get(node.key, FDBType()))
@dataclass
class Key:
key: str
value: Any
key_spec: KeySpec
reason: str
def str_value(self):
return self.key_spec.type.format(self.value)
def __bool__(self):
return self.reason in {"Matches", "Skipped", "Select All"}
def emoji(self):
return {"Matches": "", "Skipped": "⏭️", "Select All": ""}.get(
self.reason, ""
)
def info(self):
return f"{self.emoji()} {self.key:<12}= {str(self.value):<12} ({self.key_spec}) {self.reason if not self else ''}"
def __repr__(self):
return f"{self.key}={self.key_spec.type.format(self.value)}"
def as_json(self):
return dict(
key=self.key,
value=self.str_value(),
reason=self.reason,
)
class FDBSchema:
"""
Represents a parsed FDB Schema file.
Has methods to validate and convert request dictionaries to a mars request form with validation and type information.
"""
    def __init__(self, string, defaults: dict[str, str] | None = None):
"""
1. Use a PEG parser on a schema string,
2. Separate the output into schemas and typedefs
3. Insert any concrete implementations of types from fdb_types.py defaulting to generic string type
4. Walk the schema tree and annotate it with type information.
"""
        m = fdb_schema.match(string)
        if m is None:
            raise ValueError("Failed to parse FDB schema string")
        g = list(m.groups())
self._str_types, schemas = post_process(g)
self.types = {
key: FDB_type_to_implementation[type]
for key, type in self._str_types.items()
}
self.schemas = determine_types(self.types, schemas)
        self.defaults = defaults if defaults is not None else {}
def __repr__(self):
return json.dumps(
dict(schemas=self.schemas, defaults=self.defaults), indent=4, default=repr
)
@classmethod
def consume_key(
cls, key_spec: KeySpec, request: dict[str, Any]
) -> Key:
key = key_spec.key
try:
value = request[key]
except KeyError:
if key_spec.is_optional():
return Key(key_spec.key, "", key_spec, "Skipped")
if key_spec.is_allable():
return Key(key_spec.key, "", key_spec, "Select All")
else:
return Key(
key_spec.key, "", key_spec, "Key Missing"
)
if key_spec.matches(key, value):
return Key(
key_spec.key,
key_spec.type.parse(value),
key_spec,
"Matches",
)
else:
return Key(
key_spec.key, value, key_spec, "Incorrect Value"
)
@classmethod
def _DFS_match(
cls, tree: list, request: dict[str, Any]
) -> tuple[bool | list, list[Key]]:
"""Do a DFS on the schema tree, returning the deepest matching path
At each stage return whether we matched on this path, and the path itself.
When traversing the tree there are three cases to consider:
1. base case []
2. one schema [k, k, k, [k, k, k]]
3. list of schemas [[k,k,k], [k,k,k], [k,k,k]]
"""
# Case 1: Base Case
if not tree:
return True, []
# Case 2: [k, k, k, [k, k, k]]
if isinstance(tree[0], KeySpec):
node, *tree = tree
# Check if this node is in the request
match_result = cls.consume_key(node, request)
            # If it isn't then terminate this path here
if not match_result:
return False, [match_result,] # fmt: skip
# Otherwise continue walking the tree and return the best result
matched, path = cls._DFS_match(tree, request)
# Don't put the key in the path if it's optional and we're skipping it.
if match_result.reason != "Skipped":
path = [match_result,] + path # fmt: skip
return matched, path
# Case 3: [[k, k, k], [k, k, k]]
branches = []
for branch in tree:
matched, branch_path = cls._DFS_match(branch, request)
# If this branch matches, terminate the DFS and use this.
if matched:
return branch, branch_path
else:
branches.append(branch_path)
# If no branch matches, return the one with the deepest match
return False, max(branches, key=len)
@classmethod
def _DFS_match_all(
cls, tree: list, request: dict[str, Any]
) -> list[list[Key]]:
"""Do a DFS on the schema tree, returning all matching paths or partial matches.
At each stage return all matching paths and the deepest partial matches.
When traversing the tree there are three cases to consider:
1. base case []
2. one schema [k, k, k, [k, k, k]]
3. list of schemas [[k,k,k], [k,k,k], [k,k,k]]
"""
# Case 1: Base Case
if not tree:
return [[]]
# Case 2: [k, k, k, [k, k, k]]
if isinstance(tree[0], KeySpec):
node, *tree = tree
# Check if this node is in the request
request_values = request.get(node.key, None)
if request_values is None:
# If the key is not in the request, return a partial match with Key Missing
return [[Key(node.key, "", node, "Key Missing")]]
# If the request value is a list, try to match each value
if isinstance(request_values, list):
all_matches = []
for value in request_values:
match_result = cls.consume_key(node, {node.key: value})
if match_result:
sub_matches = cls._DFS_match_all(tree, request)
for match in sub_matches:
if match_result.reason != "Skipped":
match.insert(0, match_result)
all_matches.append(match)
return all_matches if all_matches else [[Key(node.key, "", node, "No Match Found")]]
else:
# Handle a single value
match_result = cls.consume_key(node, request)
# If it isn't then return a partial match with Key Missing
if not match_result:
return [[Key(node.key, "", node, "Key Missing")]]
# Continue walking the tree and get all matches
all_matches = cls._DFS_match_all(tree, request)
# Prepend the current match to all further matches
for match in all_matches:
if match_result.reason != "Skipped":
match.insert(0, match_result)
return all_matches
# Case 3: [[k, k, k], [k, k, k]]
all_branch_matches = []
for branch in tree:
branch_matches = cls._DFS_match_all(branch, request)
all_branch_matches.extend(branch_matches)
# Return all of the deepest partial matches or complete matches
return all_branch_matches
def match_all(self, request: dict[str, Any]):
        request = self.defaults | request
return self._DFS_match_all(self.schemas, request)
def match(self, request: dict[str, Any]):
        request = self.defaults | request
return self._DFS_match(self.schemas, request)
class FDBSchemaFile(FDBSchema):
def __init__(self, path: str):
with open(path, "r") as f:
            super().__init__(f.read())
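
if __name__ == "__main__":
    # A minimal usage sketch; the schema text below is illustrative, not a real
    # FDB schema: one type definition plus a single two-level schema.
    example = """
    expver: Expver;
    [ class, expver
        [ type=ofb/mfb ]]
    """
    schema = FDBSchema(example)
    matched, path = schema.match({"class": "od", "expver": "0001", "type": "ofb"})
    assert matched
    for key in path:
        print(key.info())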

View File

@@ -0,0 +1,83 @@
from dataclasses import dataclass
from typing import Any
import re
from collections import defaultdict
from datetime import datetime, date, time
@dataclass(repr=False)
class FDBType:
"""
Holds information about how to format and validate a given FDB Schema type like Time or Expver
This base type represents a string and does no validation or formatting. It's the default type.
"""
name: str = "String"
def __repr__(self) -> str:
return self.name
def validate(self, s: Any) -> bool:
try:
self.parse(s)
return True
except (ValueError, AssertionError):
return False
def format(self, s: Any) -> str:
return str(s).lower()
def parse(self, s: str) -> Any:
return s
@dataclass(repr=False)
class Expver_FDBType(FDBType):
name: str = "Expver"
    def parse(self, s: str) -> str:
        # An expver label is exactly four characters, e.g. "0001"
        assert bool(re.fullmatch(r".{4}", s))
        return s
@dataclass(repr=False)
class Time_FDBType(FDBType):
name: str = "Time"
time_format = "%H%M"
def format(self, t: time) -> str:
return t.strftime(self.time_format)
def parse(self, s: datetime | str | int) -> time:
if isinstance(s, str):
assert len(s) == 4
return datetime.strptime(s, self.time_format).time()
if isinstance(s, datetime):
return s.time()
return self.parse(f"{s:04}")
@dataclass(repr=False)
class Date_FDBType(FDBType):
name: str = "Date"
date_format: str = "%Y%m%d"
def format(self, d: Any) -> str:
if isinstance(d, date):
return d.strftime(self.date_format)
if isinstance(d, int):
return f"{d:08}"
else:
return d
def parse(self, s: datetime | str | int) -> date:
if isinstance(s, str):
return datetime.strptime(s, self.date_format).date()
elif isinstance(s, datetime):
return s.date()
return self.parse(f"{s:08}")
FDB_type_to_implementation = defaultdict(FDBType) | {
cls.name: cls() for cls in [Expver_FDBType, Time_FDBType, Date_FDBType]
}
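
if __name__ == "__main__":
    # A minimal sketch of the type machinery; unknown type names fall back to
    # the plain String type because the mapping is a defaultdict.
    t = FDB_type_to_implementation["Time"]
    assert t.format(t.parse(30)) == "0030"  # ints are zero-padded to HHMM
    d = FDB_type_to_implementation["Date"]
    assert d.format(d.parse("20240101")) == "20240101"
    assert FDB_type_to_implementation["NotARealType"].name == "String"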