Add convert_dtypes and selection with functions
This commit is contained in:
parent
d2f3165fe8
commit
df5360f29a
@ -1,8 +1,9 @@
|
||||
import dataclasses
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from typing import Any, Callable, Iterable, Iterator, Literal, Sequence
|
||||
from typing import Any, Iterable, Iterator, Literal, Sequence
|
||||
|
||||
from frozendict import frozendict
|
||||
|
||||
@ -233,6 +234,36 @@ class Qube:
|
||||
else:
|
||||
yield leaf
|
||||
|
||||
def leaves_with_metadata(
|
||||
self, indices=()
|
||||
) -> Iterable[tuple[dict[str, str], dict[str, str]]]:
|
||||
if self.key == "root":
|
||||
for c in self.children:
|
||||
for leaf in c.leaves_with_metadata(indices=()):
|
||||
yield leaf
|
||||
return
|
||||
|
||||
for index, value in enumerate(self.values):
|
||||
# print(self.key, index, indices, value)
|
||||
# print({k: np.shape(v) for k, v in self.metadata.items()})
|
||||
indexed_metadata = {
|
||||
k: vs[indices + (index,)] for k, vs in self.metadata.items()
|
||||
}
|
||||
indexed_metadata = {
|
||||
k: v.item() if v.shape == () else v for k, v in indexed_metadata.items()
|
||||
}
|
||||
if not self.children:
|
||||
yield {self.key: value}, indexed_metadata
|
||||
|
||||
for child in self.children:
|
||||
for leaf, metadata in child.leaves_with_metadata(
|
||||
indices=indices + (index,)
|
||||
):
|
||||
if self.key != "root":
|
||||
yield {self.key: value, **leaf}, metadata | indexed_metadata
|
||||
else:
|
||||
yield leaf, metadata
|
||||
|
||||
def datacubes(self) -> "Qube":
|
||||
def to_list_of_cubes(node: Qube) -> Iterable[Qube]:
|
||||
if not node.children:
|
||||
@ -241,7 +272,7 @@ class Qube:
|
||||
for c in node.children:
|
||||
# print(c)
|
||||
for sub_cube in to_list_of_cubes(c):
|
||||
yield dataclasses.replace(node, children=[sub_cube])
|
||||
yield node.replace(children=[sub_cube])
|
||||
|
||||
return Qube.root_node((q for c in self.children for q in to_list_of_cubes(c)))
|
||||
|
||||
@ -294,7 +325,7 @@ class Qube:
|
||||
"""
|
||||
|
||||
def transform(node: Qube) -> list[Qube]:
|
||||
children = [cc for c in node.children for cc in transform(c)]
|
||||
children = tuple(sorted(cc for c in node.children for cc in transform(c)))
|
||||
new_nodes = func(node)
|
||||
if isinstance(new_nodes, Qube):
|
||||
new_nodes = [new_nodes]
|
||||
@ -302,18 +333,29 @@ class Qube:
|
||||
return [new_node.replace(children=children) for new_node in new_nodes]
|
||||
|
||||
children = tuple(cc for c in self.children for cc in transform(c))
|
||||
return dataclasses.replace(self, children=children)
|
||||
return self.replace(children=children)
|
||||
|
||||
def convert_dtypes(self, converters: dict[str, Callable[[Any], Any]]):
|
||||
def convert(node: Qube) -> Qube:
|
||||
if node.key in converters:
|
||||
converter = converters[node.key]
|
||||
new_node = node.replace(values=QEnum(map(converter, node.values)))
|
||||
return new_node
|
||||
return node
|
||||
|
||||
return self.transform(convert)
|
||||
|
||||
def select(
|
||||
self,
|
||||
selection: dict[str, str | list[str]],
|
||||
selection: dict[str, str | list[str] | Callable[[Any], bool]],
|
||||
mode: Literal["strict", "relaxed"] = "relaxed",
|
||||
prune=True,
|
||||
consume=False,
|
||||
) -> "Qube":
|
||||
# make all values lists
|
||||
selection: dict[str, list[str]] = {
|
||||
k: v if isinstance(v, list) else [v] for k, v in selection.items()
|
||||
selection: dict[str, list[str] | Callable[[Any], bool]] = {
|
||||
k: v if isinstance(v, list | Callable) else [v]
|
||||
for k, v in selection.items()
|
||||
}
|
||||
|
||||
def not_none(xs):
|
||||
@ -337,25 +379,27 @@ class Qube:
|
||||
if prune and node.children and not new_children:
|
||||
return None
|
||||
|
||||
return dataclasses.replace(node, children=new_children)
|
||||
return node.replace(children=new_children)
|
||||
|
||||
# If the key is specified, check if any of the values match
|
||||
selection_criteria = selection[node.key]
|
||||
if isinstance(selection_criteria, Callable):
|
||||
values = QEnum((c for c in node.values if selection_criteria(c)))
|
||||
else:
|
||||
values = QEnum((c for c in selection[node.key] if c in node.values))
|
||||
|
||||
if not values:
|
||||
return None
|
||||
|
||||
data = dataclasses.replace(node.data, values=values)
|
||||
if consume:
|
||||
selection = {k: v for k, v in selection.items() if k != node.key}
|
||||
return dataclasses.replace(
|
||||
node,
|
||||
data=data,
|
||||
return node.replace(
|
||||
values=values,
|
||||
children=not_none(select(c, selection) for c in node.children),
|
||||
)
|
||||
|
||||
return dataclasses.replace(
|
||||
self, children=not_none(select(c, selection) for c in self.children)
|
||||
return self.replace(
|
||||
children=not_none(select(c, selection) for c in self.children)
|
||||
)
|
||||
|
||||
def span(self, key: str) -> list[str]:
|
||||
|
@ -10,7 +10,7 @@ def test_simple():
|
||||
├── param=2t, threshold=273.15
|
||||
└── param=tp, threshold=0.1/1/10/100/20/25/5/50
|
||||
""")
|
||||
q.print()
|
||||
|
||||
r = Qube.from_dict(
|
||||
{
|
||||
"frequency=6:00:00": {
|
||||
|
@ -25,3 +25,37 @@ def test_consumption_off():
|
||||
}
|
||||
)
|
||||
assert q.select({"expver": "0001"}, consume=False) == expected
|
||||
|
||||
|
||||
def test_function_input_to_select():
|
||||
q = Qube.from_tree("""
|
||||
root, frequency=6:00:00
|
||||
├── levtype=pl, param=t, levelist=850, threshold=-2/-4/-8/2/4/8
|
||||
└── levtype=sfc
|
||||
├── param=10u/10v, threshold=10/15
|
||||
├── param=2t, threshold=273.15
|
||||
└── param=tp, threshold=0.1/1/10/100/20/25/5/50
|
||||
""").convert_dtypes(
|
||||
{
|
||||
"threshold": float,
|
||||
}
|
||||
)
|
||||
|
||||
r = q.select(
|
||||
{
|
||||
"threshold": lambda t: t > 5,
|
||||
}
|
||||
)
|
||||
|
||||
assert r == Qube.from_tree("""
|
||||
root, frequency=6:00:00
|
||||
├── levtype=pl, param=t, levelist=850, threshold=8
|
||||
└── levtype=sfc
|
||||
├── param=10u/10v, threshold=10/15
|
||||
├── param=2t, threshold=273.15
|
||||
└── param=tp, threshold=10/100/20/25/50
|
||||
""").convert_dtypes(
|
||||
{
|
||||
"threshold": float,
|
||||
}
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user