import json
from collections import defaultdict
from dataclasses import asdict, dataclass
from pathlib import Path

Tree = dict[str, "Tree"]

class RefcountedDict(dict[str, int]):
    "A mapping {key -> subtree hash} with a reference count, so identical subtrees can be shared and garbage collected"
    refcount: int = 1

    def __repr__(self):
        return f"RefcountedDict(refcount={self.refcount}, {super().__repr__()})"

    def __hash__(self):
        return hash(tuple(sorted(self.items())))

@dataclass
class JSONNode:
    key: str
    values: list[str]
    children: list["JSONNode"]

class CompressedTree:
    """
    An implementation of a compressed tree that supports lookup, insertion, deletion and caching.
    The caching means that identical subtrees are stored only once, saving memory.
    This is implemented internally by storing all subtrees in a global hash table.
    """
    cache: dict[int, RefcountedDict]
    tree: RefcountedDict

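    # A rough sketch of the sharing this enables (illustrative, not part of the
    # original code): in CompressedTree({"x": {"c": {}}, "y": {"c": {}}}) the
    # subtrees under "x" and "y" hash to the same value, so the cache stores that
    # {"c": ...} level only once and bumps its refcount instead of duplicating it.
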
    def _add_to_cache(self, level: RefcountedDict) -> int:
        "Add a level {key -> hash} to the cache"
        h = hash(level)
        if h not in self.cache:
            # Increase refcounts of the child nodes
            for child_h in level.values():
                self.cache[child_h].refcount += 1
            self.cache[h] = RefcountedDict(level)
        else:
            self.cache[h].refcount += 1
        return h

    def _replace_in_cache(self, old_h: int, level: RefcountedDict) -> int:
        """
        Replace the object at old_h with a different object level.
        If the two objects are identical this is a no-op.
        """
        # Start by adding the new object to the cache
        new_h = self._add_to_cache(level)

        # Now check if the old object needs to be garbage collected
        self._decrease_refcount(old_h)

        return new_h

    def _decrease_refcount(self, h: int):
        self.cache[h].refcount -= 1
        if self.cache[h].refcount == 0:
            # Recursively decrease refcounts of child nodes
            for child_h in self.cache[h].values():
                self._decrease_refcount(child_h)
            del self.cache[h]

    def cache_tree(self, tree: Tree) -> int:
        "Insert the given tree (a dictionary of dictionaries) and all its children, recursively, into the hash table and return the hash key"
        level = RefcountedDict({k: self.cache_tree(v) for k, v in tree.items()})
        return self._add_to_cache(level)

    def _cache_path(self, path: list[str]) -> int:
        "Treat path = [x, y, z, ...] like {x: {y: {z: ...}}} and cache that"
        if not path:
            return self.empty_hash
        k, *rest = path
        return self._add_to_cache(RefcountedDict({k: self._cache_path(rest)}))

    def reconstruct(self, max_depth=None) -> dict[str, dict]:
        "Reconstruct the tree as a normal nested dictionary"
        def reconstruct_node(h: int, depth: int) -> dict[str, dict]:
            if max_depth is not None and depth > max_depth:
                return {}
            return {k: reconstruct_node(v, depth + 1) for k, v in self.cache[h].items()}
        return reconstruct_node(self.root_hash, 0)

    def reconstruct_compressed(self) -> dict[str, dict]:
        "Reconstruct the tree as a nested dictionary, merging sibling keys that point at identical subtrees into a single 'k1/k2' key"
        def reconstruct_node(h: int) -> dict[str, dict]:
            dedup: dict[int, set[str]] = defaultdict(set)
            for k, h2 in self.cache[h].items():
                dedup[h2].add(k)

            return {"/".join(keys): reconstruct_node(child_h) for child_h, keys in dedup.items()}
        return reconstruct_node(self.root_hash)

    def reconstruct_compressed_ecmwf_style(self, max_depth=None, from_node=None) -> dict[str, dict]:
        "Reconstruct the tree as a nested dictionary, merging 'key=value' siblings that point at identical subtrees into a single 'key=v1,v2' entry"
        def reconstruct_node(h: int, depth: int) -> dict[str, dict]:
            if max_depth is not None and depth > max_depth:
                return {}
            dedup: dict[tuple[int, str], set[str]] = defaultdict(set)
            for k, h2 in self.cache[h].items():
                key, value = k.split("=")
                dedup[(h2, key)].add(value)

            return {f"{key}={','.join(values)}": reconstruct_node(child_h, depth=depth + 1) for (child_h, key), values in dedup.items()}
        return reconstruct_node(from_node or self.root_hash, depth=0)

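    # For example (illustrative keys, not taken from the demo below): if
    # "type=cf" and "type=pf" both point at identical subtrees, they are
    # emitted as a single "type=cf,pf" entry.
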
    def to_json(self, max_depth=None, from_node=None) -> dict:
        def reconstruct_node(h: int, depth: int) -> list[JSONNode]:
            if max_depth is not None and depth > max_depth:
                return []
            dedup: dict[tuple[int, str], set[str]] = defaultdict(set)
            for k, h2 in self.cache[h].items():
                key, value = k.split("=")
                dedup[(h2, key)].add(value)

            return [JSONNode(
                key=key,
                values=list(values),
                children=reconstruct_node(child_h, depth=depth + 1),
            ) for (child_h, key), values in dedup.items()]

        return asdict(reconstruct_node(from_node or self.root_hash, depth=0)[0])

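    # The result (via dataclasses.asdict) is shaped roughly like
    # {"key": "type", "values": ["cf", "pf"], "children": [...]}, where each
    # child has the same shape; the key and value names here are illustrative.
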
    def __init__(self, tree: Tree):
        self.cache = {}
        self.empty_hash = hash(RefcountedDict({}))

        # Recursively cache the tree
        self.root_hash = self.cache_tree(tree)

        # Keep a reference to the root of the tree
        self.tree = self.cache[self.root_hash]

    def lookup(self, keys: tuple[str, ...]) -> tuple[bool, tuple[str, ...]]:
        """
        Lookup a subtree in the tree.
        Returns (success, path):
        if success is True the lookup reached the bottom of the tree and path is equal to keys,
        if success is False, path holds the prefix of keys that was found.
        """
        loc = self.tree
        for i, key in enumerate(keys):
            if key in loc:
                h = loc[key]  # get the hash of the subtree
                loc = self.cache[h]  # get the subtree
            else:
                return False, keys[:i]
        return True, keys

    def keys(self, keys: tuple[str, ...] = ()) -> list[str] | None:
        loc = self.tree
        for key in keys:
            if key in loc:
                h = loc[key]  # get the hash of the subtree
                loc = self.cache[h]  # get the subtree
            else:
                return None
        return list(loc.keys())

    def multi_match(self, request: dict[str, list[str]], loc=None):
        if loc is None:
            loc = self.tree
        if loc == {}:
            return {"_END_": {}}
        matches = {}
        for request_key, request_values in request.items():
            for request_value in request_values:
                meta_key = f"{request_key}={request_value}"
                if meta_key in loc:
                    new_loc = self.cache[loc[meta_key]]
                    matches[meta_key] = self.multi_match(request, new_loc)

        if not matches:
            return {k: {} for k in loc.keys()}
        return matches

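    # multi_match expects node keys of the form "key=value", so a request such
    # as {"param": ["2t", "tp"], "levtype": ["sfc"]} (illustrative names) walks
    # every matching "param=..."/"levtype=..." branch at once; leaves are marked
    # with "_END_" and levels with no matching key are returned as their keys.
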
    def _insert(self, old_h: int, tree: RefcountedDict, keys: tuple[str, ...]) -> int:
        "Insert keys in the subtree and return the new hash of the subtree"
        key, *rest = keys
        assert old_h in self.cache

        # Adding a new branch to the tree
        if key not in tree:
            new_tree = RefcountedDict(tree | {key: self._cache_path(rest)})

        else:
            # Make a copy of the tree and update the subtree
            new_tree = RefcountedDict(tree.copy())
            subtree_h = tree[key]
            subtree = self.cache[subtree_h]
            new_tree[key] = self._insert(subtree_h, subtree, tuple(rest))

        # no-op if the hash hasn't changed
        new_h = self._replace_in_cache(old_h, new_tree)
        return new_h

    def insert(self, keys: tuple[str, ...]):
        """
        Insert a new branch into the compressed tree.
        """
        already_there, path = self.lookup(keys)
        if already_there:
            return
        # Update the tree
        self.root_hash = self._insert(self.root_hash, self.tree, keys)
        self.tree = self.cache[self.root_hash]

    def insert_tree(self, subtree: Tree):
        """
        Insert a whole tree into the compressed tree.
        """
        self.root_hash = self._insert_tree(self.root_hash, self.tree, subtree)
        self.tree = self.cache[self.root_hash]

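    # For example (illustrative keys): merging the subtree {"a": {"c": {}}}
    # into an existing tree {"a": {"b": {}}} via insert_tree yields
    # {"a": {"b": {}, "c": {}}}.
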
    def _insert_tree(self, old_h: int, tree: RefcountedDict, subtree: Tree) -> int:
        """
        Recursively insert a subtree into the compressed tree and return the new hash.
        """
        assert old_h in self.cache

        # Make a copy of the tree to avoid modifying shared structures
        new_tree = RefcountedDict(tree.copy())
        for key, sub_subtree in subtree.items():
            if key not in tree:
                # Key is not in the current tree: cache the subtree rooted at
                # sub_subtree and add it as a new branch
                subtree_h = self.cache_tree(sub_subtree)
                new_tree[key] = subtree_h
            else:
                # Key is in the tree, need to recursively merge
                # Get the hash and subtree from the current tree
                child_h = tree[key]
                child_tree = self.cache[child_h]
                # Recursively merge
                new_child_h = self._insert_tree(child_h, child_tree, sub_subtree)
                new_tree[key] = new_child_h

        # Replace the old hash with the new one in the cache
        new_h = self._replace_in_cache(old_h, new_tree)
        return new_h

    def save(self, path: Path):
        "Save the compressed tree to a file"
        with open(path, "w") as f:
            json.dump({
                "cache": {k: {"refcount": v.refcount, "dict": v} for k, v in self.cache.items()},
                "root_hash": self.root_hash,
            }, f)

    @classmethod
    def load(cls, path: Path) -> "CompressedTree":
        "Load the compressed tree from a file"
        with open(path) as f:
            data = json.load(f)
        return cls.from_json(data)

    @classmethod
    def from_json(cls, data: dict) -> "CompressedTree":
        c = CompressedTree({})
        c.cache = {}
        for k, v in data["cache"].items():
            c.cache[int(k)] = RefcountedDict(v["dict"])
            c.cache[int(k)].refcount = v["refcount"]

        c.root_hash = data["root_hash"]
        c.tree = c.cache[c.root_hash]
        return c

if __name__ == "__main__":
    original_tree = {
        "a": {
            "b1": {
                "c": {}
            },
            "b2": {
                "c": {}
            },
            "b3*": {
                "c*": {}
            }
        }
    }

    c_tree = CompressedTree(original_tree)

    assert c_tree.lookup(("a", "b1", "c")) == (True, ("a", "b1", "c"))
    assert c_tree.lookup(("a", "b1", "d")) == (False, ("a", "b1"))

    print(json.dumps(c_tree.reconstruct_compressed(), indent=4))

    assert c_tree.reconstruct() == original_tree

    c_tree.insert(("a", "b1", "d"))
    c_tree.insert(("a", "b2", "d"))
    print(json.dumps(c_tree.reconstruct(), indent=4))

    print(json.dumps(c_tree.reconstruct_compressed(), indent=4))
    print(c_tree.cache)

    # test round trip
    assert CompressedTree(original_tree).reconstruct() == original_tree

    # test adding a key
    added_keys_tree = {
        "a": {
            "b1": {
                "c": {}
            },
            "b2": {
                "c": {},
                "d": {}
            },
            "b3*": {
                "c*": {},
                "d*": {}
            }
        }
    }
    c_tree = CompressedTree(original_tree)
    c_tree.insert(("a", "b2", "d"))
    c_tree.insert(("a", "b3*", "d*"))
    assert c_tree.reconstruct() == added_keys_tree

    print(c_tree.reconstruct_compressed())
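
    # A minimal sketch of the save/load round trip (an illustrative addition;
    # assumes the working directory is writable and "compressed_tree.json" is a
    # throwaway file name).
    save_path = Path("compressed_tree.json")
    c_tree.save(save_path)
    loaded = CompressedTree.load(save_path)
    assert loaded.reconstruct() == c_tree.reconstruct()
    save_path.unlink()  # clean up the temporary file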