From 07f9a24daa736e4f6b54976e4a5636285af90ade Mon Sep 17 00:00:00 2001 From: Tom Date: Wed, 30 Apr 2025 14:05:36 +0200 Subject: [PATCH] Add require_match argument to select --- src/python/qubed/Qube.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/python/qubed/Qube.py b/src/python/qubed/Qube.py index 0ff9966..e5cbaf8 100644 --- a/src/python/qubed/Qube.py +++ b/src/python/qubed/Qube.py @@ -416,8 +416,8 @@ class Qube: self, selection: dict[str, str | list[str] | Callable[[Any], bool]], mode: Literal["strict", "relaxed"] = "relaxed", - prune=True, consume=False, + require_match=False, ) -> Qube: # Find any bare str values and replace them with [str] _selection: dict[str, list[str] | Callable[[Any], bool]] = {} @@ -433,7 +433,9 @@ class Qube: return tuple(x for x in xs if x is not None) def select( - node: Qube, selection: dict[str, list[str] | Callable[[Any], bool]] + node: Qube, + selection: dict[str, list[str] | Callable[[Any], bool]], + matched: bool, ) -> Qube | None: # If this node has no children but there are still parts of the request # that have not been consumed, then prune this whole branch @@ -475,16 +477,23 @@ class Qube: if not values: return None + matched = True node = node.replace(values=values) if consume: selection = {k: v for k, v in selection.items() if k != node.key} + # prune branches with no matches + if require_match and not node.children and not matched: + return None + # Prune nodes that had had all their children pruned - new_children = not_none(select(c, selection) for c in node.children) + new_children = not_none( + select(c, selection, matched) for c in node.children + ) # if node.key == "dataset": print(prune, [(c.key, c.values.values) for c in node.children], [c.key for c in new_children]) - if prune and node.children and not new_children: + if node.children and not new_children: return None return node.replace( @@ -493,7 +502,9 @@ class Qube: ) return self.replace( - children=not_none(select(c, _selection) for c in self.children) + children=not_none( + select(c, _selection, matched=False) for c in self.children + ) ) def span(self, key: str) -> list[str]: