From df5360f29a8e05a84e45d70ea304e3c7bbb61b8c Mon Sep 17 00:00:00 2001 From: Tom Date: Thu, 27 Mar 2025 16:02:58 +0000 Subject: [PATCH] Add convert_dtypes and selection with functions --- src/python/qubed/Qube.py | 74 ++++++++++++++++++++++++++++++++-------- tests/test_creation.py | 2 +- tests/test_selection.py | 34 ++++++++++++++++++ 3 files changed, 94 insertions(+), 16 deletions(-) diff --git a/src/python/qubed/Qube.py b/src/python/qubed/Qube.py index f46f2d5..0b38e60 100644 --- a/src/python/qubed/Qube.py +++ b/src/python/qubed/Qube.py @@ -1,8 +1,9 @@ import dataclasses from collections import defaultdict +from collections.abc import Callable from dataclasses import dataclass from functools import cached_property -from typing import Any, Callable, Iterable, Iterator, Literal, Sequence +from typing import Any, Iterable, Iterator, Literal, Sequence from frozendict import frozendict @@ -233,6 +234,36 @@ class Qube: else: yield leaf + def leaves_with_metadata( + self, indices=() + ) -> Iterable[tuple[dict[str, str], dict[str, str]]]: + if self.key == "root": + for c in self.children: + for leaf in c.leaves_with_metadata(indices=()): + yield leaf + return + + for index, value in enumerate(self.values): + # print(self.key, index, indices, value) + # print({k: np.shape(v) for k, v in self.metadata.items()}) + indexed_metadata = { + k: vs[indices + (index,)] for k, vs in self.metadata.items() + } + indexed_metadata = { + k: v.item() if v.shape == () else v for k, v in indexed_metadata.items() + } + if not self.children: + yield {self.key: value}, indexed_metadata + + for child in self.children: + for leaf, metadata in child.leaves_with_metadata( + indices=indices + (index,) + ): + if self.key != "root": + yield {self.key: value, **leaf}, metadata | indexed_metadata + else: + yield leaf, metadata + def datacubes(self) -> "Qube": def to_list_of_cubes(node: Qube) -> Iterable[Qube]: if not node.children: @@ -241,7 +272,7 @@ class Qube: for c in node.children: # print(c) for sub_cube in to_list_of_cubes(c): - yield dataclasses.replace(node, children=[sub_cube]) + yield node.replace(children=[sub_cube]) return Qube.root_node((q for c in self.children for q in to_list_of_cubes(c))) @@ -294,7 +325,7 @@ class Qube: """ def transform(node: Qube) -> list[Qube]: - children = [cc for c in node.children for cc in transform(c)] + children = tuple(sorted(cc for c in node.children for cc in transform(c))) new_nodes = func(node) if isinstance(new_nodes, Qube): new_nodes = [new_nodes] @@ -302,18 +333,29 @@ class Qube: return [new_node.replace(children=children) for new_node in new_nodes] children = tuple(cc for c in self.children for cc in transform(c)) - return dataclasses.replace(self, children=children) + return self.replace(children=children) + + def convert_dtypes(self, converters: dict[str, Callable[[Any], Any]]): + def convert(node: Qube) -> Qube: + if node.key in converters: + converter = converters[node.key] + new_node = node.replace(values=QEnum(map(converter, node.values))) + return new_node + return node + + return self.transform(convert) def select( self, - selection: dict[str, str | list[str]], + selection: dict[str, str | list[str] | Callable[[Any], bool]], mode: Literal["strict", "relaxed"] = "relaxed", prune=True, consume=False, ) -> "Qube": # make all values lists - selection: dict[str, list[str]] = { - k: v if isinstance(v, list) else [v] for k, v in selection.items() + selection: dict[str, list[str] | Callable[[Any], bool]] = { + k: v if isinstance(v, list | Callable) else [v] + for k, v in selection.items() } def not_none(xs): @@ -337,25 +379,27 @@ class Qube: if prune and node.children and not new_children: return None - return dataclasses.replace(node, children=new_children) + return node.replace(children=new_children) # If the key is specified, check if any of the values match - values = QEnum((c for c in selection[node.key] if c in node.values)) + selection_criteria = selection[node.key] + if isinstance(selection_criteria, Callable): + values = QEnum((c for c in node.values if selection_criteria(c))) + else: + values = QEnum((c for c in selection[node.key] if c in node.values)) if not values: return None - data = dataclasses.replace(node.data, values=values) if consume: selection = {k: v for k, v in selection.items() if k != node.key} - return dataclasses.replace( - node, - data=data, + return node.replace( + values=values, children=not_none(select(c, selection) for c in node.children), ) - return dataclasses.replace( - self, children=not_none(select(c, selection) for c in self.children) + return self.replace( + children=not_none(select(c, selection) for c in self.children) ) def span(self, key: str) -> list[str]: diff --git a/tests/test_creation.py b/tests/test_creation.py index 77f16da..147485b 100644 --- a/tests/test_creation.py +++ b/tests/test_creation.py @@ -10,7 +10,7 @@ def test_simple(): ├── param=2t, threshold=273.15 └── param=tp, threshold=0.1/1/10/100/20/25/5/50 """) - q.print() + r = Qube.from_dict( { "frequency=6:00:00": { diff --git a/tests/test_selection.py b/tests/test_selection.py index 777d525..6e56f2b 100644 --- a/tests/test_selection.py +++ b/tests/test_selection.py @@ -25,3 +25,37 @@ def test_consumption_off(): } ) assert q.select({"expver": "0001"}, consume=False) == expected + + +def test_function_input_to_select(): + q = Qube.from_tree(""" + root, frequency=6:00:00 + ├── levtype=pl, param=t, levelist=850, threshold=-2/-4/-8/2/4/8 + └── levtype=sfc + ├── param=10u/10v, threshold=10/15 + ├── param=2t, threshold=273.15 + └── param=tp, threshold=0.1/1/10/100/20/25/5/50 + """).convert_dtypes( + { + "threshold": float, + } + ) + + r = q.select( + { + "threshold": lambda t: t > 5, + } + ) + + assert r == Qube.from_tree(""" + root, frequency=6:00:00 + ├── levtype=pl, param=t, levelist=850, threshold=8 + └── levtype=sfc + ├── param=10u/10v, threshold=10/15 + ├── param=2t, threshold=273.15 + └── param=tp, threshold=10/100/20/25/50 + """).convert_dtypes( + { + "threshold": float, + } + )