Add convert_dtypes and selection with functions
This commit is contained in:
parent
d2f3165fe8
commit
df5360f29a
@ -1,8 +1,9 @@
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Any, Callable, Iterable, Iterator, Literal, Sequence
|
from typing import Any, Iterable, Iterator, Literal, Sequence
|
||||||
|
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
|
|
||||||
@ -233,6 +234,36 @@ class Qube:
|
|||||||
else:
|
else:
|
||||||
yield leaf
|
yield leaf
|
||||||
|
|
||||||
|
def leaves_with_metadata(
    self, indices=()
) -> Iterable[tuple[dict[str, str], dict[str, str]]]:
    """Yield an ``(identifier, metadata)`` pair for every leaf below this node.

    ``identifier`` maps each key on the path from this node down to a leaf
    to the single value picked at that level. ``metadata`` collects, for
    every metadata key met along that path, the entry selected by the same
    picks.

    Args:
        indices: Positions of the values already picked in the ancestor
            nodes, outermost first. Each metadata array on this node is
            indexed with ``indices + (index,)``.

    Yields:
        ``(identifier, metadata)`` tuples. Metadata entries whose shape is
        ``()`` are unwrapped with ``.item()``; any other entry is passed
        through exactly as the indexing returned it.
    """
    # The root adds nothing to the identifier or to the index path: it
    # only forwards what its children produce, starting from an empty path.
    if self.key == "root":
        for child in self.children:
            yield from child.leaves_with_metadata(indices=())
        return

    for index, value in enumerate(self.values):
        path = indices + (index,)

        # Pick this node's metadata entries for the current path.
        own_metadata = {}
        for name, array in self.metadata.items():
            entry = array[path]
            own_metadata[name] = entry.item() if entry.shape == () else entry

        if not self.children:
            yield {self.key: value}, own_metadata
            continue

        for child in self.children:
            for leaf, metadata in child.leaves_with_metadata(indices=path):
                # On a key clash this node's metadata overrides the
                # descendants' (right-hand operand of ``|`` wins).
                yield {self.key: value, **leaf}, metadata | own_metadata
|
||||||
|
|
||||||
def datacubes(self) -> "Qube":
|
def datacubes(self) -> "Qube":
|
||||||
def to_list_of_cubes(node: Qube) -> Iterable[Qube]:
|
def to_list_of_cubes(node: Qube) -> Iterable[Qube]:
|
||||||
if not node.children:
|
if not node.children:
|
||||||
@ -241,7 +272,7 @@ class Qube:
|
|||||||
for c in node.children:
|
for c in node.children:
|
||||||
# print(c)
|
# print(c)
|
||||||
for sub_cube in to_list_of_cubes(c):
|
for sub_cube in to_list_of_cubes(c):
|
||||||
yield dataclasses.replace(node, children=[sub_cube])
|
yield node.replace(children=[sub_cube])
|
||||||
|
|
||||||
return Qube.root_node((q for c in self.children for q in to_list_of_cubes(c)))
|
return Qube.root_node((q for c in self.children for q in to_list_of_cubes(c)))
|
||||||
|
|
||||||
@ -294,7 +325,7 @@ class Qube:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def transform(node: Qube) -> list[Qube]:
|
def transform(node: Qube) -> list[Qube]:
|
||||||
children = [cc for c in node.children for cc in transform(c)]
|
children = tuple(sorted(cc for c in node.children for cc in transform(c)))
|
||||||
new_nodes = func(node)
|
new_nodes = func(node)
|
||||||
if isinstance(new_nodes, Qube):
|
if isinstance(new_nodes, Qube):
|
||||||
new_nodes = [new_nodes]
|
new_nodes = [new_nodes]
|
||||||
@ -302,18 +333,29 @@ class Qube:
|
|||||||
return [new_node.replace(children=children) for new_node in new_nodes]
|
return [new_node.replace(children=children) for new_node in new_nodes]
|
||||||
|
|
||||||
children = tuple(cc for c in self.children for cc in transform(c))
|
children = tuple(cc for c in self.children for cc in transform(c))
|
||||||
return dataclasses.replace(self, children=children)
|
return self.replace(children=children)
|
||||||
|
|
||||||
|
def convert_dtypes(self, converters: dict[str, Callable[[Any], Any]]):
    """Apply per-key value converters across the tree via ``self.transform``.

    Args:
        converters: Maps a node key to a callable that is applied to each
            of that node's values. Nodes whose key is not listed are
            returned unchanged.

    Returns:
        Whatever ``self.transform`` returns for the conversion function.
    """

    def convert(node: Qube) -> Qube:
        # Keys without a converter pass through untouched.
        if node.key not in converters:
            return node
        cast = converters[node.key]
        return node.replace(values=QEnum(cast(v) for v in node.values))

    return self.transform(convert)
|
||||||
|
|
||||||
def select(
|
def select(
|
||||||
self,
|
self,
|
||||||
selection: dict[str, str | list[str]],
|
selection: dict[str, str | list[str] | Callable[[Any], bool]],
|
||||||
mode: Literal["strict", "relaxed"] = "relaxed",
|
mode: Literal["strict", "relaxed"] = "relaxed",
|
||||||
prune=True,
|
prune=True,
|
||||||
consume=False,
|
consume=False,
|
||||||
) -> "Qube":
|
) -> "Qube":
|
||||||
# make all values lists
|
# make all values lists
|
||||||
selection: dict[str, list[str]] = {
|
selection: dict[str, list[str] | Callable[[Any], bool]] = {
|
||||||
k: v if isinstance(v, list) else [v] for k, v in selection.items()
|
k: v if isinstance(v, list | Callable) else [v]
|
||||||
|
for k, v in selection.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
def not_none(xs):
|
def not_none(xs):
|
||||||
@ -337,25 +379,27 @@ class Qube:
|
|||||||
if prune and node.children and not new_children:
|
if prune and node.children and not new_children:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return dataclasses.replace(node, children=new_children)
|
return node.replace(children=new_children)
|
||||||
|
|
||||||
# If the key is specified, check if any of the values match
|
# If the key is specified, check if any of the values match
|
||||||
values = QEnum((c for c in selection[node.key] if c in node.values))
|
selection_criteria = selection[node.key]
|
||||||
|
if isinstance(selection_criteria, Callable):
|
||||||
|
values = QEnum((c for c in node.values if selection_criteria(c)))
|
||||||
|
else:
|
||||||
|
values = QEnum((c for c in selection[node.key] if c in node.values))
|
||||||
|
|
||||||
if not values:
|
if not values:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
data = dataclasses.replace(node.data, values=values)
|
|
||||||
if consume:
|
if consume:
|
||||||
selection = {k: v for k, v in selection.items() if k != node.key}
|
selection = {k: v for k, v in selection.items() if k != node.key}
|
||||||
return dataclasses.replace(
|
return node.replace(
|
||||||
node,
|
values=values,
|
||||||
data=data,
|
|
||||||
children=not_none(select(c, selection) for c in node.children),
|
children=not_none(select(c, selection) for c in node.children),
|
||||||
)
|
)
|
||||||
|
|
||||||
return dataclasses.replace(
|
return self.replace(
|
||||||
self, children=not_none(select(c, selection) for c in self.children)
|
children=not_none(select(c, selection) for c in self.children)
|
||||||
)
|
)
|
||||||
|
|
||||||
def span(self, key: str) -> list[str]:
|
def span(self, key: str) -> list[str]:
|
||||||
|
@ -10,7 +10,7 @@ def test_simple():
|
|||||||
├── param=2t, threshold=273.15
|
├── param=2t, threshold=273.15
|
||||||
└── param=tp, threshold=0.1/1/10/100/20/25/5/50
|
└── param=tp, threshold=0.1/1/10/100/20/25/5/50
|
||||||
""")
|
""")
|
||||||
q.print()
|
|
||||||
r = Qube.from_dict(
|
r = Qube.from_dict(
|
||||||
{
|
{
|
||||||
"frequency=6:00:00": {
|
"frequency=6:00:00": {
|
||||||
|
@ -25,3 +25,37 @@ def test_consumption_off():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
assert q.select({"expver": "0001"}, consume=False) == expected
|
assert q.select({"expver": "0001"}, consume=False) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_function_input_to_select():
    """Selecting with a predicate keeps only the values it accepts."""
    # Thresholds are converted with ``float`` on both sides so the numeric
    # predicate below can be applied and the results compare equal.
    as_float = {"threshold": float}

    source = Qube.from_tree("""
    root, frequency=6:00:00
    ├── levtype=pl, param=t, levelist=850, threshold=-2/-4/-8/2/4/8
    └── levtype=sfc
        ├── param=10u/10v, threshold=10/15
        ├── param=2t, threshold=273.15
        └── param=tp, threshold=0.1/1/10/100/20/25/5/50
    """).convert_dtypes(as_float)

    selected = source.select({"threshold": lambda t: t > 5})

    expected = Qube.from_tree("""
    root, frequency=6:00:00
    ├── levtype=pl, param=t, levelist=850, threshold=8
    └── levtype=sfc
        ├── param=10u/10v, threshold=10/15
        ├── param=2t, threshold=273.15
        └── param=tp, threshold=10/100/20/25/50
    """).convert_dtypes(as_float)

    assert selected == expected
|
||||||
|
Loading…
x
Reference in New Issue
Block a user