Add creation from tree representation
This commit is contained in:
parent
9beaaa2e10
commit
6b98f7b7a9
@ -2,7 +2,7 @@ import dataclasses
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from typing import Any, Callable, Iterable, Literal, Sequence
|
||||
from typing import Any, Callable, Iterable, Iterator, Literal, Sequence
|
||||
|
||||
from frozendict import frozendict
|
||||
|
||||
@ -104,17 +104,66 @@ class Qube:
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> "Qube":
|
||||
def from_dict(d: dict) -> list[Qube]:
|
||||
return [
|
||||
Qube.make(
|
||||
key=k.split("=")[0],
|
||||
values=QEnum((k.split("=")[1].split("/"))),
|
||||
def from_dict(d: dict) -> Iterator[Qube]:
|
||||
for k, children in d.items():
|
||||
key, values = k.split("=")
|
||||
values = values.split("/")
|
||||
if values == ["*"]:
|
||||
values = WildcardGroup()
|
||||
else:
|
||||
values = QEnum(values)
|
||||
|
||||
yield Qube.make(
|
||||
key=key,
|
||||
values=values,
|
||||
children=from_dict(children),
|
||||
)
|
||||
for k, children in d.items()
|
||||
]
|
||||
|
||||
return Qube.root_node(from_dict(d))
|
||||
return Qube.root_node(list(from_dict(d)))
|
||||
|
||||
@classmethod
|
||||
def from_tree(cls, tree_str):
|
||||
lines = tree_str.splitlines()
|
||||
stack = []
|
||||
root = {}
|
||||
|
||||
initial_indent = None
|
||||
for line in lines:
|
||||
if not line.strip():
|
||||
continue
|
||||
# Remove tree characters and measure indent level
|
||||
stripped = line.lstrip(" │├└─")
|
||||
indent = (len(line) - len(stripped)) // 4
|
||||
if initial_indent is None:
|
||||
initial_indent = indent
|
||||
print(f"Initial indent {initial_indent}")
|
||||
indent = indent - initial_indent
|
||||
|
||||
# Split multiple key=value parts into nested structure
|
||||
keys = [item.strip() for item in stripped.split(",")]
|
||||
current = bottom = {}
|
||||
for key in reversed(keys):
|
||||
current = {key: current}
|
||||
|
||||
# Adjust the stack to current indent level
|
||||
# print(len(stack), stack)
|
||||
while len(stack) > indent:
|
||||
stack.pop()
|
||||
|
||||
if stack:
|
||||
# Add to the dictionary at current stack level
|
||||
parent = stack[-1]
|
||||
key = list(current.keys())[0]
|
||||
parent[key] = current[key]
|
||||
else:
|
||||
# Top level
|
||||
key = list(current.keys())[0]
|
||||
root = current[key]
|
||||
|
||||
# Push to the stack
|
||||
stack.append(bottom)
|
||||
|
||||
return cls.from_dict(root)
|
||||
|
||||
@classmethod
|
||||
def empty(cls) -> "Qube":
|
||||
@ -210,7 +259,7 @@ class Qube:
|
||||
break
|
||||
else:
|
||||
raise KeyError(
|
||||
f"Key '{key}' not found in children of '{current.key}'"
|
||||
f"Key '{key}' not found in children of '{current.key}', available keys are {[c.key for c in current.children]}"
|
||||
)
|
||||
return Qube.root_node(current.children)
|
||||
|
||||
@ -261,21 +310,28 @@ class Qube:
|
||||
selection: dict[str, str | list[str]],
|
||||
mode: Literal["strict", "relaxed"] = "relaxed",
|
||||
prune=True,
|
||||
consume=True,
|
||||
consume=False,
|
||||
) -> "Qube":
|
||||
# make all values lists
|
||||
selection = {k: v if isinstance(v, list) else [v] for k, v in selection.items()}
|
||||
selection: dict[str, list[str]] = {
|
||||
k: v if isinstance(v, list) else [v] for k, v in selection.items()
|
||||
}
|
||||
|
||||
def not_none(xs):
|
||||
return tuple(x for x in xs if x is not None)
|
||||
|
||||
def select(node: Qube) -> Qube | None:
|
||||
def select(node: Qube, selection: dict[str, list[str]]) -> Qube | None:
|
||||
# If this node has no children but there are still parts of the request
|
||||
# that have not been consumed, then prune this whole branch
|
||||
if consume and not node.children and selection:
|
||||
return None
|
||||
|
||||
# Check if the key is specified in the selection
|
||||
if node.key not in selection:
|
||||
if mode == "strict":
|
||||
return None
|
||||
|
||||
new_children = not_none(select(c) for c in node.children)
|
||||
new_children = not_none(select(c, selection) for c in node.children)
|
||||
|
||||
# prune==true then remove any non-leaf nodes
|
||||
# which have had all their children removed
|
||||
@ -291,12 +347,16 @@ class Qube:
|
||||
return None
|
||||
|
||||
data = dataclasses.replace(node.data, values=values)
|
||||
if consume:
|
||||
selection = {k: v for k, v in selection.items() if k != node.key}
|
||||
return dataclasses.replace(
|
||||
node, data=data, children=not_none(select(c) for c in node.children)
|
||||
node,
|
||||
data=data,
|
||||
children=not_none(select(c, selection) for c in node.children),
|
||||
)
|
||||
|
||||
return dataclasses.replace(
|
||||
self, children=not_none(select(c) for c in self.children)
|
||||
self, children=not_none(select(c, selection) for c in self.children)
|
||||
)
|
||||
|
||||
def span(self, key: str) -> list[str]:
|
||||
|
56
tests/test_creation.py
Normal file
56
tests/test_creation.py
Normal file
@ -0,0 +1,56 @@
|
||||
from qubed import Qube
|
||||
|
||||
|
||||
def test_simple():
|
||||
q = Qube.from_tree("""
|
||||
root, frequency=6:00:00
|
||||
├── levtype=pl, param=t, levelist=850, threshold=-2/-4/-8/2/4/8
|
||||
└── levtype=sfc
|
||||
├── param=10u/10v, threshold=10/15
|
||||
├── param=2t, threshold=273.15
|
||||
└── param=tp, threshold=0.1/1/10/100/20/25/5/50
|
||||
""")
|
||||
q.print()
|
||||
r = Qube.from_dict(
|
||||
{
|
||||
"frequency=6:00:00": {
|
||||
"levtype=pl": {
|
||||
"param=t": {"levelist=850": {"threshold=-8/-4/-2/2/4/8": {}}}
|
||||
},
|
||||
"levtype=sfc": {
|
||||
"param=10u/10v": {"threshold=10/15": {}},
|
||||
"param=2t": {"threshold=273.15": {}},
|
||||
"param=tp": {"threshold=0.1/1/5/10/20/25/50/100": {}},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
assert q == r
|
||||
|
||||
|
||||
def test_simple_2():
|
||||
models = Qube.from_datacube(
|
||||
dict(
|
||||
param="10u/10v/2d/2t/cp/msl/skt/sp/tcw/tp".split("/"),
|
||||
threshold="*",
|
||||
levtype="sfc",
|
||||
frequency="6:00:00",
|
||||
)
|
||||
) | Qube.from_datacube(
|
||||
dict(
|
||||
param="q/t/u/v/w/z".split("/"),
|
||||
threshold="*",
|
||||
levtype="pl",
|
||||
level="50/100/150/200/250/300/400/500/600/700/850".split("/"),
|
||||
frequency="6:00:00",
|
||||
)
|
||||
)
|
||||
|
||||
models2 = Qube.from_tree("""
|
||||
models
|
||||
├── param=10u/10v/2d/2t/cp/msl/skt/sp/tcw/tp, threshold=*, levtype=sfc, frequency=6:00:00
|
||||
└── param=q/t/u/v/w/z, threshold=*, levtype=pl, level=100/150/200/250/300/400/50/500/600/700/850, frequency=6:00:00
|
||||
""")
|
||||
|
||||
assert models == models2
|
Loading…
x
Reference in New Issue
Block a user