Add require_match argument to select

This commit is contained in:
Tom 2025-04-30 14:05:36 +02:00
parent b13a06a0cc
commit 07f9a24daa

View File

@ -416,8 +416,8 @@ class Qube:
self, self,
selection: dict[str, str | list[str] | Callable[[Any], bool]], selection: dict[str, str | list[str] | Callable[[Any], bool]],
mode: Literal["strict", "relaxed"] = "relaxed", mode: Literal["strict", "relaxed"] = "relaxed",
prune=True,
consume=False, consume=False,
require_match=False,
) -> Qube: ) -> Qube:
# Find any bare str values and replace them with [str] # Find any bare str values and replace them with [str]
_selection: dict[str, list[str] | Callable[[Any], bool]] = {} _selection: dict[str, list[str] | Callable[[Any], bool]] = {}
@ -433,7 +433,9 @@ class Qube:
return tuple(x for x in xs if x is not None) return tuple(x for x in xs if x is not None)
def select( def select(
node: Qube, selection: dict[str, list[str] | Callable[[Any], bool]] node: Qube,
selection: dict[str, list[str] | Callable[[Any], bool]],
matched: bool,
) -> Qube | None: ) -> Qube | None:
# If this node has no children but there are still parts of the request # If this node has no children but there are still parts of the request
# that have not been consumed, then prune this whole branch # that have not been consumed, then prune this whole branch
@ -475,16 +477,23 @@ class Qube:
if not values: if not values:
return None return None
matched = True
node = node.replace(values=values) node = node.replace(values=values)
if consume: if consume:
selection = {k: v for k, v in selection.items() if k != node.key} selection = {k: v for k, v in selection.items() if k != node.key}
# prune branches with no matches
if require_match and not node.children and not matched:
return None
# Prune nodes that had had all their children pruned # Prune nodes that had had all their children pruned
new_children = not_none(select(c, selection) for c in node.children) new_children = not_none(
select(c, selection, matched) for c in node.children
)
# if node.key == "dataset": print(prune, [(c.key, c.values.values) for c in node.children], [c.key for c in new_children]) # if node.key == "dataset": print(prune, [(c.key, c.values.values) for c in node.children], [c.key for c in new_children])
if prune and node.children and not new_children: if node.children and not new_children:
return None return None
return node.replace( return node.replace(
@ -493,7 +502,9 @@ class Qube:
) )
return self.replace( return self.replace(
children=not_none(select(c, _selection) for c in self.children) children=not_none(
select(c, _selection, matched=False) for c in self.children
)
) )
def span(self, key: str) -> list[str]: def span(self, key: str) -> list[str]: