Add convert_dtypes and selection with functions

This commit is contained in:
Tom 2025-03-27 16:02:58 +00:00
parent d2f3165fe8
commit df5360f29a
3 changed files with 94 additions and 16 deletions

View File

@ -1,8 +1,9 @@
import dataclasses
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass
from functools import cached_property
from typing import Any, Callable, Iterable, Iterator, Literal, Sequence
from typing import Any, Iterable, Iterator, Literal, Sequence
from frozendict import frozendict
@ -233,6 +234,36 @@ class Qube:
else:
yield leaf
def leaves_with_metadata(
    self, indices=()
) -> Iterable[tuple[dict[str, Any], dict[str, Any]]]:
    """Yield an ``(identifier, metadata)`` pair for every leaf below this node.

    ``identifier`` maps each key on the path from this node down to the leaf
    to the single value chosen at that level. ``metadata`` holds, for every
    metadata name found on that path, the entry selected by the same choice
    of values.

    Args:
        indices: Positions of the values already chosen in the ancestor
            nodes, outermost first. Each metadata entry of this node is
            indexed with ``indices + (index,)``.
            NOTE(review): this assumes every value in ``self.metadata``
            supports tuple indexing and exposes ``.shape`` / ``.item()``
            (e.g. a numpy array with one axis per ancestor level) — confirm.

    Yields:
        Tuples ``(identifier, metadata)``. When a metadata name appears on
        several levels of the path, the entry from the node closest to the
        root wins.
    """
    if self.key == "root":
        # The root contributes neither a key nor an index: delegate to the
        # children and start their index path from scratch.
        for child in self.children:
            yield from child.leaves_with_metadata(indices=())
        return

    for index, value in enumerate(self.values):
        path = indices + (index,)

        indexed_metadata = {}
        for name, array in self.metadata.items():
            entry = array[path]
            # Zero-dimensional results are unwrapped into plain Python scalars.
            indexed_metadata[name] = entry.item() if entry.shape == () else entry

        if not self.children:
            yield {self.key: value}, indexed_metadata
            continue

        for child in self.children:
            for leaf, metadata in child.leaves_with_metadata(indices=path):
                # This node's metadata takes precedence over the descendants'.
                yield {self.key: value, **leaf}, metadata | indexed_metadata
def datacubes(self) -> "Qube":
def to_list_of_cubes(node: Qube) -> Iterable[Qube]:
if not node.children:
@ -241,7 +272,7 @@ class Qube:
for c in node.children:
# print(c)
for sub_cube in to_list_of_cubes(c):
yield dataclasses.replace(node, children=[sub_cube])
yield node.replace(children=[sub_cube])
return Qube.root_node((q for c in self.children for q in to_list_of_cubes(c)))
@ -294,7 +325,7 @@ class Qube:
"""
def transform(node: Qube) -> list[Qube]:
children = [cc for c in node.children for cc in transform(c)]
children = tuple(sorted(cc for c in node.children for cc in transform(c)))
new_nodes = func(node)
if isinstance(new_nodes, Qube):
new_nodes = [new_nodes]
@ -302,18 +333,29 @@ class Qube:
return [new_node.replace(children=children) for new_node in new_nodes]
children = tuple(cc for c in self.children for cc in transform(c))
return dataclasses.replace(self, children=children)
return self.replace(children=children)
def convert_dtypes(self, converters: dict[str, Callable[[Any], Any]]):
    """Return the result of ``self.transform`` with node values converted.

    Args:
        converters: Maps a node key to a one-argument callable. Every node
            whose key appears here gets its values replaced by a ``QEnum``
            built from the callable applied to each existing value; all
            other nodes are passed through untouched.
    """

    def _cast_node(node: Qube) -> Qube:
        # Keys without a registered converter keep their values as they are.
        if node.key not in converters:
            return node
        cast = converters[node.key]
        converted_values = QEnum(map(cast, node.values))
        return node.replace(values=converted_values)

    return self.transform(_cast_node)
def select(
self,
selection: dict[str, str | list[str]],
selection: dict[str, str | list[str] | Callable[[Any], bool]],
mode: Literal["strict", "relaxed"] = "relaxed",
prune=True,
consume=False,
) -> "Qube":
# make all values lists
selection: dict[str, list[str]] = {
k: v if isinstance(v, list) else [v] for k, v in selection.items()
selection: dict[str, list[str] | Callable[[Any], bool]] = {
k: v if isinstance(v, list | Callable) else [v]
for k, v in selection.items()
}
def not_none(xs):
@ -337,25 +379,27 @@ class Qube:
if prune and node.children and not new_children:
return None
return dataclasses.replace(node, children=new_children)
return node.replace(children=new_children)
# If the key is specified, check if any of the values match
values = QEnum((c for c in selection[node.key] if c in node.values))
selection_criteria = selection[node.key]
if isinstance(selection_criteria, Callable):
values = QEnum((c for c in node.values if selection_criteria(c)))
else:
values = QEnum((c for c in selection[node.key] if c in node.values))
if not values:
return None
data = dataclasses.replace(node.data, values=values)
if consume:
selection = {k: v for k, v in selection.items() if k != node.key}
return dataclasses.replace(
node,
data=data,
return node.replace(
values=values,
children=not_none(select(c, selection) for c in node.children),
)
return dataclasses.replace(
self, children=not_none(select(c, selection) for c in self.children)
return self.replace(
children=not_none(select(c, selection) for c in self.children)
)
def span(self, key: str) -> list[str]:

View File

@ -10,7 +10,7 @@ def test_simple():
param=2t, threshold=273.15
param=tp, threshold=0.1/1/10/100/20/25/5/50
""")
q.print()
r = Qube.from_dict(
{
"frequency=6:00:00": {

View File

@ -25,3 +25,37 @@ def test_consumption_off():
}
)
assert q.select({"expver": "0001"}, consume=False) == expected
def test_function_input_to_select():
# Checks that Qube.select accepts a predicate function as a selection value,
# not only a literal value or a list of values.
q = Qube.from_tree("""
root, frequency=6:00:00
levtype=pl, param=t, levelist=850, threshold=-2/-4/-8/2/4/8
levtype=sfc
param=10u/10v, threshold=10/15
param=2t, threshold=273.15
param=tp, threshold=0.1/1/10/100/20/25/5/50
""").convert_dtypes(
{
"threshold": float,
}
)
# The threshold values are passed through float above so that the predicate
# below can compare them numerically.
# NOTE(review): presumably from_tree yields string values — confirm.
r = q.select(
{
"threshold": lambda t: t > 5,
}
)
# Only thresholds strictly greater than 5 remain in the expected tree (the
# value 5 itself is dropped); the expectation gets the same float conversion
# so both sides of the comparison hold values of the same type.
assert r == Qube.from_tree("""
root, frequency=6:00:00
levtype=pl, param=t, levelist=850, threshold=8
levtype=sfc
param=10u/10v, threshold=10/15
param=2t, threshold=273.15
param=tp, threshold=10/100/20/25/50
""").convert_dtypes(
{
"threshold": float,
}
)