new code
This commit is contained in:
parent
fcdf4e0d51
commit
f51f5dcb42
395
notebooks/DataCubeTree.py
Normal file
395
notebooks/DataCubeTree.py
Normal file
@ -0,0 +1,395 @@
|
||||
import dataclasses
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, datetime, timedelta
|
||||
from typing import Any, Callable, Iterable, Literal
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class HTML():
|
||||
html: str
|
||||
def _repr_html_(self):
|
||||
return self.html
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Values(ABC):
|
||||
@abstractmethod
|
||||
def summary(self) -> str:
|
||||
pass
|
||||
@abstractmethod
|
||||
def __len__(self) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __contains__(self, value: Any) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def from_strings(self, values: list[str]) -> list['Values']:
|
||||
pass
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Enum(Values):
|
||||
"""
|
||||
The simplest kind of key value is just a list of strings.
|
||||
summary -> string1/string2/string....
|
||||
"""
|
||||
values: list[Any]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.values)
|
||||
def summary(self) -> str:
|
||||
return '/'.join(sorted(self.values))
|
||||
def __contains__(self, value: Any) -> bool:
|
||||
return value in self.values
|
||||
def from_strings(self, values: list[str]) -> list['Values']:
|
||||
return [Enum(values)]
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Range(Values, ABC):
|
||||
dtype: str = dataclasses.field(kw_only=True)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DateRange(Range):
|
||||
start: date
|
||||
end: date
|
||||
step: timedelta
|
||||
dtype: Literal["date"] = dataclasses.field(kw_only=True, default="date")
|
||||
|
||||
@classmethod
|
||||
def from_strings(self, values: list[str]) -> list['DateRange']:
|
||||
dates = sorted([datetime.strptime(v, "%Y%m%d") for v in values])
|
||||
if len(dates) < 2:
|
||||
return [DateRange(
|
||||
start=dates[0],
|
||||
end=dates[0],
|
||||
step=timedelta(days=0)
|
||||
)]
|
||||
|
||||
ranges = []
|
||||
current_range, dates = [dates[0],], dates[1:]
|
||||
while len(dates) > 1:
|
||||
if dates[0] - current_range[-1] == timedelta(days=1):
|
||||
current_range.append(dates.pop(0))
|
||||
|
||||
elif len(current_range) == 1:
|
||||
ranges.append(DateRange(
|
||||
start=current_range[0],
|
||||
end=current_range[0],
|
||||
step=timedelta(days=0)
|
||||
))
|
||||
current_range = [dates.pop(0),]
|
||||
|
||||
else:
|
||||
ranges.append(DateRange(
|
||||
start=current_range[0],
|
||||
end=current_range[-1],
|
||||
step=timedelta(days=1)
|
||||
))
|
||||
current_range = [dates.pop(0),]
|
||||
return ranges
|
||||
|
||||
def __contains__(self, value: Any) -> bool:
|
||||
v = datetime.strptime(value, "%Y%m%d").date()
|
||||
return self.start <= v <= self.end and (v - self.start) % self.step == 0
|
||||
|
||||
|
||||
def __len__(self) -> int:
|
||||
return (self.end - self.start) // self.step
|
||||
|
||||
def summary(self) -> str:
|
||||
def fmt(d): return d.strftime("%Y%m%d")
|
||||
if self.step == timedelta(days=0):
|
||||
return f"{fmt(self.start)}"
|
||||
if self.step == timedelta(days=1):
|
||||
return f"{fmt(self.start)}/to/{fmt(self.end)}"
|
||||
|
||||
return f"{fmt(self.start)}/to/{fmt(self.end)}/by/{self.step // timedelta(days=1)}"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TimeRange(Range):
|
||||
start: int
|
||||
end: int
|
||||
step: int
|
||||
dtype: Literal["time"] = dataclasses.field(kw_only=True, default="time")
|
||||
|
||||
@classmethod
|
||||
def from_strings(self, values: list[str]) -> list['TimeRange']:
|
||||
if len(values) == 0: return []
|
||||
|
||||
times = sorted([int(v) for v in values])
|
||||
if len(times) < 2:
|
||||
return [TimeRange(
|
||||
start=times[0],
|
||||
end=times[0],
|
||||
step=100
|
||||
)]
|
||||
|
||||
ranges = []
|
||||
current_range, times = [times[0],], times[1:]
|
||||
while len(times) > 1:
|
||||
if times[0] - current_range[-1] == 1:
|
||||
current_range.append(times.pop(0))
|
||||
|
||||
elif len(current_range) == 1:
|
||||
ranges.append(TimeRange(
|
||||
start=current_range[0],
|
||||
end=current_range[0],
|
||||
step=0
|
||||
))
|
||||
current_range = [times.pop(0),]
|
||||
|
||||
else:
|
||||
ranges.append(TimeRange(
|
||||
start=current_range[0],
|
||||
end=current_range[-1],
|
||||
step=1
|
||||
))
|
||||
current_range = [times.pop(0),]
|
||||
return ranges
|
||||
|
||||
def __len__(self) -> int:
|
||||
return (self.end - self.start) // self.step
|
||||
|
||||
def summary(self) -> str:
|
||||
def fmt(d): return f"{d:04d}"
|
||||
if self.step == 0:
|
||||
return f"{fmt(self.start)}"
|
||||
return f"{fmt(self.start)}/to/{fmt(self.end)}/by/{self.step}"
|
||||
|
||||
def __contains__(self, value: Any) -> bool:
|
||||
v = int(value)
|
||||
return self.start <= v <= self.end and (v - self.start) % self.step == 0
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IntRange(Range):
|
||||
dtype: Literal["int"]
|
||||
start: int
|
||||
end: int
|
||||
step: int
|
||||
dtype: Literal["int"] = dataclasses.field(kw_only=True, default="int")
|
||||
|
||||
def __len__(self) -> int:
|
||||
return (self.end - self.start) // self.step
|
||||
|
||||
def summary(self) -> str:
|
||||
def fmt(d): return d.strftime("%Y%m%d")
|
||||
return f"{fmt(self.start)}/to/{fmt(self.end)}/by/{self.step}"
|
||||
|
||||
def __contains__(self, value: Any) -> bool:
|
||||
v = int(value)
|
||||
return self.start <= v <= self.end and (v - self.start) % self.step == 0
|
||||
|
||||
|
||||
def values_from_json(obj) -> Values:
|
||||
if isinstance(obj, list):
|
||||
return Enum(obj)
|
||||
|
||||
match obj["dtype"]:
|
||||
case "date": return DateRange(**obj)
|
||||
case "time": return TimeRange(**obj)
|
||||
case "int": return IntRange(**obj)
|
||||
case _: raise ValueError(f"Unknown dtype {obj['dtype']}")
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Node:
|
||||
key: str
|
||||
values: Values # Must support len()
|
||||
metadata: dict[str, str] # Applies to all children
|
||||
payload: list[Any] # List of size product(len(n.values) for n in ancestors(self))
|
||||
children: list['Node']
|
||||
|
||||
def summarize_node(node: Node) -> tuple[str, Node]:
|
||||
"""
|
||||
Extracts a summarized representation of the node while collapsing single-child paths.
|
||||
Returns the summary string and the last node in the chain that has multiple children.
|
||||
"""
|
||||
summary = []
|
||||
|
||||
while True:
|
||||
values_summary = node.values.summary()
|
||||
if len(values_summary) > 50:
|
||||
values_summary = values_summary[:50] + "..."
|
||||
summary.append(f"{node.key}={values_summary}")
|
||||
|
||||
# Move down if there's exactly one child, otherwise stop
|
||||
if len(node.children) != 1:
|
||||
break
|
||||
node = node.children[0]
|
||||
|
||||
return ", ".join(summary), node
|
||||
|
||||
def node_tree_to_string(node : Node, prefix : str = "", depth = None) -> Iterable[str]:
|
||||
summary, node = summarize_node(node)
|
||||
|
||||
if depth is not None and depth <= 0:
|
||||
yield summary + " - ...\n"
|
||||
return
|
||||
# Special case for nodes with only a single child, this makes the printed representation more compact
|
||||
elif len(node.children) == 1:
|
||||
yield summary + ", "
|
||||
yield from node_tree_to_string(node.children[0], prefix, depth = depth)
|
||||
return
|
||||
else:
|
||||
yield summary + "\n"
|
||||
|
||||
for index, child in enumerate(node.children):
|
||||
connector = "└── " if index == len(node.children) - 1 else "├── "
|
||||
yield prefix + connector
|
||||
extension = " " if index == len(node.children) - 1 else "│ "
|
||||
yield from node_tree_to_string(child, prefix + extension, depth = depth - 1 if depth is not None else None)
|
||||
|
||||
def node_tree_to_html(node : Node, prefix : str = "", depth = 1, connector = "") -> Iterable[str]:
|
||||
summary, node = summarize_node(node)
|
||||
|
||||
if len(node.children) == 0:
|
||||
yield f'<span class="leaf">{connector}{summary}</span>'
|
||||
return
|
||||
else:
|
||||
open = "open" if depth > 0 else ""
|
||||
yield f"<details {open}><summary>{connector}{summary}</summary>"
|
||||
|
||||
for index, child in enumerate(node.children):
|
||||
connector = "└── " if index == len(node.children) - 1 else "├── "
|
||||
extension = " " if index == len(node.children) - 1 else "│ "
|
||||
yield from node_tree_to_html(child, prefix + extension, depth = depth - 1, connector = prefix+connector)
|
||||
yield "</details>"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CompressedTree:
|
||||
root: Node
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json: dict) -> 'CompressedTree':
|
||||
def from_json(json: dict) -> Node:
|
||||
return Node(
|
||||
key=json["key"],
|
||||
values=values_from_json(json["values"]),
|
||||
metadata=json["metadata"] if "metadata" in json else {},
|
||||
payload=json["payload"] if "payload" in json else [],
|
||||
children=[from_json(c) for c in json["children"]]
|
||||
)
|
||||
return CompressedTree(root=from_json(json))
|
||||
|
||||
def __str__(self):
|
||||
return "".join(node_tree_to_string(node=self.root))
|
||||
|
||||
def html(self, depth = 2) -> HTML:
|
||||
return HTML(self._repr_html_(depth = depth))
|
||||
|
||||
def _repr_html_(self, depth = 2):
|
||||
css = """
|
||||
<style>
|
||||
.qubed-tree-view {
|
||||
font-family: monospace;
|
||||
white-space: pre;
|
||||
}
|
||||
.qubed-tree-view details {
|
||||
# display: inline;
|
||||
margin-left: 0;
|
||||
}
|
||||
.qubed-tree-view summary {
|
||||
list-style: none;
|
||||
cursor: pointer;
|
||||
text-overflow: ellipsis;
|
||||
overflow: hidden;
|
||||
text-wrap: nowrap;
|
||||
display: block;
|
||||
}
|
||||
|
||||
.qubed-tree-view .leaf {
|
||||
text-overflow: ellipsis;
|
||||
overflow: hidden;
|
||||
text-wrap: nowrap;
|
||||
display: block;
|
||||
}
|
||||
|
||||
.qubed-tree-view summary:hover,span.leaf:hover {
|
||||
background-color: #f0f0f0;
|
||||
}
|
||||
.qubed-tree-view details > summary::after {
|
||||
content: ' ';
|
||||
}
|
||||
.qubed-tree-view details:not([open]) > summary::after {
|
||||
content: " ▼";
|
||||
}
|
||||
</style>
|
||||
|
||||
"""
|
||||
nodes = "".join(cc for c in self.root.children for cc in node_tree_to_html(node=c, depth=depth))
|
||||
return f"{css}<pre class='qubed-tree-view'>{nodes}</pre>"
|
||||
|
||||
def print(self, depth = None):
|
||||
print("".join(cc for c in self.root.children for cc in node_tree_to_string(node=c, depth = depth)))
|
||||
|
||||
def transform(self, func: Callable[[Node], Node]) -> 'CompressedTree':
|
||||
"Call a function on every node of the tree, any changes to the children of a node will be ignored."
|
||||
def transform(node: Node) -> Node:
|
||||
new_node = func(node)
|
||||
return dataclasses.replace(new_node, children = [transform(c) for c in node.children])
|
||||
return CompressedTree(root=transform(self.root))
|
||||
|
||||
def guess_datatypes(self) -> 'CompressedTree':
|
||||
def guess_datatypes(node: Node) -> list[Node]:
|
||||
# Try to convert enum values into more structured types
|
||||
children = [cc for c in node.children for cc in guess_datatypes(c)]
|
||||
|
||||
if isinstance(node.values, Enum):
|
||||
match node.key:
|
||||
case "time": range_class = TimeRange
|
||||
case "date": range_class = DateRange
|
||||
case _: range_class = None
|
||||
|
||||
if range_class is not None:
|
||||
return [
|
||||
dataclasses.replace(node, values = range, children = children)
|
||||
for range in range_class.from_strings(node.values.values)
|
||||
]
|
||||
return [dataclasses.replace(node, children = children)]
|
||||
|
||||
children = [cc for c in self.root.children for cc in guess_datatypes(c)]
|
||||
return CompressedTree(root=dataclasses.replace(self.root, children = children))
|
||||
|
||||
|
||||
def select(self, selection : dict[str, str | list[str]], mode: Literal["strict", "relaxed"] = "relaxed") -> 'CompressedTree':
|
||||
# make all values lists
|
||||
selection = {k : v if isinstance(v, list) else [v] for k,v in selection.items()}
|
||||
|
||||
def not_none(xs): return [x for x in xs if x is not None]
|
||||
|
||||
def select(node: Node) -> Node | None:
|
||||
# Check if the key is specified in the selection
|
||||
if node.key not in selection:
|
||||
if mode == "strict":
|
||||
return None
|
||||
return dataclasses.replace(node, children = not_none(select(c) for c in node.children))
|
||||
|
||||
# If the key is specified, check if any of the values match
|
||||
values = Enum([ c for c in selection[node.key] if c in node.values])
|
||||
|
||||
if not values:
|
||||
return None
|
||||
|
||||
return dataclasses.replace(node, values = values, children = not_none(select(c) for c in node.children))
|
||||
|
||||
return CompressedTree(root=dataclasses.replace(self.root, children = not_none(select(c) for c in self.root.children)))
|
||||
|
||||
def to_list_of_cubes(self):
|
||||
def to_list_of_cubes(node: Node) -> list[list[Node]]:
|
||||
return [[node] + sub_cube for c in node.children for sub_cube in to_list_of_cubes(c)]
|
||||
|
||||
return to_list_of_cubes(self.root)
|
||||
|
||||
def info(self):
|
||||
cubes = self.to_list_of_cubes()
|
||||
print(f"Number of distinct paths: {len(cubes)}")
|
||||
|
||||
|
||||
|
||||
|
||||
# What should the interace look like?
|
||||
|
||||
# tree = CompressedTree.from_json(...)
|
||||
# tree = CompressedTree.from_protobuf(...)
|
||||
|
||||
# tree.print(depth = 5) # Prints a nice tree representation
|
231
notebooks/test.ipynb
Normal file
231
notebooks/test.ipynb
Normal file
File diff suppressed because one or more lines are too long
12
tree_compresser/tests/new_format.py
Normal file
12
tree_compresser/tests/new_format.py
Normal file
@ -0,0 +1,12 @@
|
||||
from pathlib import Path
|
||||
|
||||
import orjson as json
|
||||
from tree_traverser.DataCubeTree import CompressedTree
|
||||
|
||||
data_path = Path("./config/climate-dt/new_format.json")
|
||||
with data_path.open("r") as f:
|
||||
compressed_tree = CompressedTree.from_json(json.loads(f.read()))
|
||||
|
||||
compressed_tree = compressed_tree.guess_datatypes()
|
||||
|
||||
compressed_tree.print(depth = 10)
|
Loading…
x
Reference in New Issue
Block a user