2024-12-11 14:09:45 +00:00

306 lines
11 KiB
Python

import json
from collections import defaultdict
from pathlib import Path
Tree = dict[str, "Tree"]
class RefcountedDict(dict[str, int]):
refcount: int = 1
def __repr__(self):
return f"RefcountedDict(refcount={self.refcount}, {super().__repr__()})"
def __hash__(self):
return hash(tuple(sorted(self.items())))
class CompressedTree():
"""
A implementation of a compressed tree that supports lookup, insertion, deletion and caching.
The caching means that identical subtrees are stored only once, saving memory
This is implemented internal by storing all subtrees in a global hash table
"""
cache: dict[int, RefcountedDict]
tree: RefcountedDict
def _add_to_cache(self, level : RefcountedDict) -> int:
"Add a level {key -> hash} to the cache"
h = hash(level)
if h not in self.cache:
# Increase refcounts of the child nodes
for child_h in level.values():
self.cache[child_h].refcount += 1
self.cache[h] = RefcountedDict(level)
else:
self.cache[h].refcount += 1
return h
def _replace_in_cache(self, old_h, level : RefcountedDict) -> int:
"""
Replace the object at old_h with a different object level
If the objects this is a no-op
"""
# Start by adding the new object to the cache
new_h = self._add_to_cache(level)
# Now check if the old object needs to be garbage collected
self._decrease_refcount(old_h)
return new_h
def _decrease_refcount(self, h : int):
self.cache[h].refcount -= 1
if self.cache[h].refcount == 0:
# Recursively decrease refcounts of child nodes
for child_h in self.cache[h].values():
self._decrease_refcount(child_h)
del self.cache[h]
def cache_tree(self, tree : Tree) -> int:
"Insert the given tree (dictonary of dictionaries) (all it's children, recursively) into the hash table and return the hash key"
level = RefcountedDict({k : self.cache_tree(v) for k, v in tree.items()})
return self._add_to_cache(level)
def _cache_path(self, path : list[str]) -> int:
"Treat path = [x, y, z...] like {x : {y : {z : ...}}} and cache that"
if not path:
return self.empty_hash
k, *rest = path
return self._add_to_cache(RefcountedDict({k : self._cache_path(rest)}))
def reconstruct(self, max_depth = None) -> dict[str, dict]:
"Reconstruct the tree as a normal nested dictionary"
def reconstruct_node(h : int, depth : int) -> dict[str, dict]:
if max_depth is not None and depth > max_depth:
return {}
return {k : reconstruct_node(v, depth+1) for k, v in self.cache[h].items()}
return reconstruct_node(self.root_hash, 0)
def reconstruct_compressed(self) -> dict[str, dict]:
"Reconstruct the tree as a normal nested dictionary"
def reconstruct_node(h : int) -> dict[str, dict]:
dedup : dict[int, set[str]] = defaultdict(set)
for k, h2 in self.cache[h].items():
dedup[h2].add(k)
return {"/".join(keys) : reconstruct_node(h) for h, keys in dedup.items()}
return reconstruct_node(self.root_hash)
def reconstruct_compressed_ecmwf_style(self, max_depth=None, from_node=None) -> dict[str, dict]:
"Reconstruct the tree as a normal nested dictionary"
def reconstruct_node(h : int, depth : int) -> dict[str, dict]:
if max_depth is not None and depth > max_depth:
return {}
dedup : dict[tuple[int, str], set[str]] = defaultdict(set)
for k, h2 in self.cache[h].items():
key, value = k.split("=")
dedup[(h2, key)].add(value)
return {f"{key}={','.join(values)}" : reconstruct_node(h, depth=depth+1) for (h, key), values in dedup.items()}
return reconstruct_node(from_node or self.root_hash, depth=0)
def __init__(self, tree : Tree):
self.cache = {}
self.empty_hash = hash(RefcountedDict({}))
# Recursively cache the tree
self.root_hash = self.cache_tree(tree)
# Keep a reference to the root of the tree
self.tree = self.cache[self.root_hash]
def lookup(self, keys : tuple[str, ...]) -> tuple[bool, tuple[str, ...]]:
"""
Lookup a subtree in the tree
Returns success, path
if success == True it means the path got to the bottom of the tree and path will be equal to keys
if success == False, path will holds the keys that were found
"""
loc = self.tree
for i, key in enumerate(keys):
if key in loc:
h = loc[key] # get the hash of the subtree
loc = self.cache[h] # get the subtree
else:
return False, keys[:i], h
return True, keys, h
def keys(self, keys : tuple[str, ...] = ()) -> list[str] | None:
loc = self.tree
for i, key in enumerate(keys):
if key in loc:
h = loc[key] # get the hash of the subtree
loc = self.cache[h] # get the subtree
else:
return None
return list(loc.keys())
def multi_match(self, request : dict[str, list[str]], loc = None):
if not loc: return {"_END_" : {}}
if loc is None: loc = self.tree
matches = {}
for request_key, request_values in request.items():
for request_value in request_values:
meta_key = f"{request_key}={request_value}"
if meta_key in loc:
new_loc = self.cache[loc[meta_key]]
matches[meta_key] = self.multi_match(request, new_loc)
if not matches: return {k : {} for k in loc.items()}
return matches
def _insert(self, old_h : int, tree: RefcountedDict, keys : tuple[str, ...]) -> int:
"Insert keys in the subtree and return the new hash of the subtree"
key, *rest = keys
assert old_h in self.cache
# Adding a new branch to the tree
if key not in tree:
new_tree = RefcountedDict(tree | {key : self._cache_path(rest)})
else:
# Make a copy of the tree and update the subtree
new_tree = RefcountedDict(tree.copy())
subtree_h = tree[key]
subtree = self.cache[subtree_h]
new_tree[key] = self._insert(subtree_h, subtree, tuple(rest))
# no-op if the hash hasn't changed
new_h = self._replace_in_cache(old_h, new_tree)
return new_h
def insert(self, keys : tuple[str, ...]):
"""
Insert a new branch into the compressed tree
"""
already_there, path = self.lookup(keys)
if already_there:
return
# Update the tree
self.root_hash = self._insert(self.root_hash, self.tree, keys)
self.tree = self.cache[self.root_hash]
def insert_tree(self, subtree: Tree):
"""
Insert a whole tree into the compressed tree.
"""
self.root_hash = self._insert_tree(self.root_hash, self.tree, subtree)
self.tree = self.cache[self.root_hash]
def _insert_tree(self, old_h: int, tree: RefcountedDict, subtree: Tree) -> int:
"""
Recursively insert a subtree into the compressed tree and return the new hash.
"""
assert old_h in self.cache
# Make a copy of the tree to avoid modifying shared structures
new_tree = RefcountedDict(tree.copy())
for key, sub_subtree in subtree.items():
if key not in tree:
# Key is not in current tree, add the subtree
# Cache the subtree rooted at sub_subtree
subtree_h = self.cache_tree(sub_subtree)
new_tree[key] = subtree_h
else:
# Key is in tree, need to recursively merge
# Get the hash and subtree from the current tree
child_h = tree[key]
child_tree = self.cache[child_h]
# Recursively merge
new_child_h = self._insert_tree(child_h, child_tree, sub_subtree)
new_tree[key] = new_child_h
# Replace the old hash with the new one in the cache
new_h = self._replace_in_cache(old_h, new_tree)
return new_h
def save(self, path : Path):
"Save the compressed tree to a file"
with open(path, "w") as f:
json.dump({
"cache" : {k : {"refcount" : v.refcount, "dict" : v} for k, v in self.cache.items()},
"root_hash": self.root_hash
}, f)
@classmethod
def load(cls, path : Path) -> "CompressedTree":
"Load the compressed tree from a file"
with open(path) as f:
data = json.load(f)
return cls.from_json(data)
@classmethod
def from_json(cls, data : dict) -> "CompressedTree":
c = CompressedTree({})
c.cache = {}
for k, v in data["cache"].items():
c.cache[int(k)] = RefcountedDict(v["dict"])
c.cache[int(k)].refcount = v["refcount"]
c.root_hash = data["root_hash"]
c.tree = c.cache[c.root_hash]
return c
if __name__ == "__main__":
    # Small smoke test / demo, run only when the module is executed directly.

    def _show(obj):
        """Pretty-print a nested dictionary as indented JSON."""
        print(json.dumps(obj, indent = 4))

    # b1 and b2 have identical subtrees, so they share one cached node.
    original_tree = {
        "a": {
            "b1": {"c": {}},
            "b2": {"c": {}},
            "b3*": {"c*": {}},
        }
    }

    c_tree = CompressedTree(original_tree)

    # A full path is found; a partial one reports the prefix that exists.
    found = c_tree.lookup(("a", "b1", "c"))
    assert found == (True, ("a", "b1", "c"))
    missing = c_tree.lookup(("a", "b1", "d"))
    assert missing == (False, ("a", "b1"))

    _show(c_tree.reconstruct_compressed())

    assert c_tree.reconstruct() == original_tree

    # Adding the same leaf under b1 and b2 keeps them identical (and shared).
    for branch in (("a", "b1", "d"), ("a", "b2", "d")):
        c_tree.insert(branch)

    _show(c_tree.reconstruct())
    _show(c_tree.reconstruct_compressed())
    print(c_tree.cache)

    # test round trip
    round_tripped = CompressedTree(original_tree).reconstruct()
    assert round_tripped == original_tree

    # test adding a key
    added_keys_tree = {
        "a": {
            "b1": {"c": {}},
            "b2": {"c": {}, "d": {}},
            "b3*": {"c*": {}, "d*": {}},
        }
    }
    c_tree = CompressedTree(original_tree)
    c_tree.insert(("a", "b2", "d"))
    c_tree.insert(("a", "b3*", "d*"))
    assert c_tree.reconstruct() == added_keys_tree
    print(c_tree.reconstruct_compressed())