From 1e81a1b67f5aa8b17899186142a26eb422886065 Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Fri, 17 Mar 2023 14:22:20 +0000 Subject: [PATCH] feat(lib): add ts_query_cursor_set_max_start_depth query API This allows configuring cursors from traversing too deep into a tree. --- cli/src/tests/query_test.rs | 97 +++++++++++++++++++++++++++++++++++ lib/binding_rust/bindings.rs | 3 ++ lib/binding_rust/lib.rs | 8 +++ lib/include/tree_sitter/api.h | 8 +++ lib/src/query.c | 36 ++++++++++--- 5 files changed, 144 insertions(+), 8 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 7d01c26e..4743bf9e 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -4469,6 +4469,103 @@ fn test_capture_quantifiers() { }); } +#[test] +fn test_query_max_start_depth() { + struct Row { + description: &'static str, + pattern: &'static str, + depth: u32, + matches: &'static [(usize, &'static [(&'static str, &'static str)])], + } + + let source = r#" + if (a1 && a2) { + if (b1 && b2) { } + if (c) { } + } + if (d) { + if (e1 && e2) { } + if (f) { } + } + "#; + + let rows = &[ + Row { + description: "depth 0: match none", + depth: 0, + pattern: r#" + (if_statement) @capture + "#, + matches: &[] + }, + Row { + description: "depth 1: match 2 if statements at the top level", + depth: 1, + pattern: r#" + (if_statement) @capture + "#, + matches : &[ + (0, &[("capture", "if (a1 && a2) {\n if (b1 && b2) { }\n if (c) { }\n }")]), + (0, &[("capture", "if (d) {\n if (e1 && e2) { }\n if (f) { }\n }")]) + ] + }, + Row { + description: "depth 1 with deep pattern: match the only the first if statement", + depth: 1, + pattern: r#" + (if_statement + condition: (parenthesized_expression + (binary_expression) + ) + ) @capture + "#, + matches: &[ + (0, &[("capture", "if (a1 && a2) {\n if (b1 && b2) { }\n if (c) { }\n }")]), + ] + }, + Row { + description: "depth 3 with deep pattern: match all if statements with a binexpr condition", + depth: 3, + pattern: r#" + (if_statement + condition: (parenthesized_expression + (binary_expression) + ) + ) @capture + "#, + matches: &[ + (0, &[("capture", "if (a1 && a2) {\n if (b1 && b2) { }\n if (c) { }\n }")]), + (0, &[("capture", "if (b1 && b2) { }")]), + (0, &[("capture", "if (e1 && e2) { }")]) + ] + }, + ]; + + allocations::record(|| { + let language = get_language("c"); + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(source, None).unwrap(); + let mut cursor = QueryCursor::new(); + + for row in rows.iter() { + eprintln!(" query example: {:?}", row.description); + + let query = Query::new(language, row.pattern).unwrap(); + cursor.set_max_start_depth(row.depth); + + let matches = cursor.matches(&query, tree.root_node(), source.as_bytes()); + let expected = row + .matches + .iter() + .map(|x| (x.0, x.1.to_vec())) + .collect::>(); + + assert_eq!(collect_matches(matches, &query, source), expected); + } + }); +} + fn assert_query_matches( language: Language, query: &Query, diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index 863b1df5..158d1ba1 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -565,6 +565,9 @@ extern "C" { extern "C" { pub fn ts_query_cursor_set_point_range(arg1: *mut TSQueryCursor, arg2: TSPoint, arg3: TSPoint); } +extern "C" { + pub fn ts_query_cursor_set_max_start_depth(arg1: *mut TSQueryCursor, arg2: u32); +} extern "C" { #[doc = " Advance to the next match of the currently running query.\n\n If there is a match, write it to `*match` and return `true`.\n Otherwise, return `false`."] pub fn ts_query_cursor_next_match(arg1: *mut TSQueryCursor, match_: *mut TSQueryMatch) -> bool; diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 9d470457..87294a5d 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -1932,6 +1932,14 @@ impl QueryCursor { } self } + + #[doc(alias = "ts_query_cursor_set_max_start_depth")] + pub fn set_max_start_depth(&mut self, max_start_depth: u32) -> &mut Self { + unsafe { + ffi::ts_query_cursor_set_max_start_depth(self.ptr.as_ptr(), max_start_depth); + } + self + } } impl<'a, 'tree> QueryMatch<'a, 'tree> { diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index edc1c36a..9dc058e8 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -892,6 +892,14 @@ bool ts_query_cursor_next_capture( uint32_t *capture_index ); +/** + * Set the maximum start depth for a cursor. + * + * This prevents cursors from exploring children nodes at a certain depth. + * Note if a pattern includes many children, then they will still be checked. + */ +void ts_query_cursor_set_max_start_depth(TSQueryCursor *, uint32_t); + /**********************/ /* Section - Language */ /**********************/ diff --git a/lib/src/query.c b/lib/src/query.c index da7a4166..dc6ab784 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -305,6 +305,7 @@ struct TSQueryCursor { Array(QueryState) finished_states; CaptureListPool capture_list_pool; uint32_t depth; + uint32_t max_start_depth; uint32_t start_byte; uint32_t end_byte; TSPoint start_point; @@ -2976,6 +2977,7 @@ TSQueryCursor *ts_query_cursor_new(void) { .end_byte = UINT32_MAX, .start_point = {0, 0}, .end_point = POINT_MAX, + .max_start_depth = UINT32_MAX, }; array_reserve(&self->states, 8); array_reserve(&self->finished_states, 8); @@ -3346,9 +3348,15 @@ static QueryState *ts_query_cursor__copy_state( return &self->states.contents[state_index + 1]; } -static inline bool ts_query_cursor__should_descend_outside_of_range( - TSQueryCursor *self +static inline bool ts_query_cursor__should_descend( + TSQueryCursor *self, + bool node_intersects_range ) { + + if (node_intersects_range && self->depth < self->max_start_depth) { + return true; + } + // If there are in-progress matches whose remaining steps occur // deeper in the tree, then descend. for (unsigned i = 0; i < self->states.size; i++) { @@ -3362,6 +3370,10 @@ static inline bool ts_query_cursor__should_descend_outside_of_range( } } + if (self->depth >= self->max_start_depth) { + return false; + } + // If the current node is hidden, then a non-rooted pattern might match // one if its roots inside of this node, and match another of its roots // as part of a sibling node, so we may need to descend. @@ -3555,12 +3567,14 @@ static inline bool ts_query_cursor__advance( // If this node matches the first step of the pattern, then add a new // state at the start of this pattern. QueryStep *step = &self->query->steps.contents[pattern->step_index]; + uint32_t start_depth = self->depth - step->depth; if ( (pattern->is_rooted ? node_intersects_range : (parent_intersects_range && !parent_is_error)) && (!step->field || field_id == step->field) && - (!step->supertype_symbol || supertype_count > 0) + (!step->supertype_symbol || supertype_count > 0) && + (start_depth <= self->max_start_depth) ) { ts_query_cursor__add_state(self, pattern); } @@ -3573,6 +3587,7 @@ static inline bool ts_query_cursor__advance( PatternEntry *pattern = &self->query->pattern_map.contents[i]; QueryStep *step = &self->query->steps.contents[pattern->step_index]; + uint32_t start_depth = self->depth - step->depth; do { // If this node matches the first step of the pattern, then add a new // state at the start of this pattern. @@ -3580,7 +3595,8 @@ static inline bool ts_query_cursor__advance( (pattern->is_rooted ? node_intersects_range : (parent_intersects_range && !parent_is_error)) && - (!step->field || field_id == step->field) + (!step->field || field_id == step->field) && + (start_depth <= self->max_start_depth) ) { ts_query_cursor__add_state(self, pattern); } @@ -3881,10 +3897,7 @@ static inline bool ts_query_cursor__advance( } } - bool should_descend = - node_intersects_range || - ts_query_cursor__should_descend_outside_of_range(self); - if (should_descend) { + if (ts_query_cursor__should_descend(self, node_intersects_range)) { switch (ts_tree_cursor_goto_first_child_internal(&self->cursor)) { case TreeCursorStepVisible: self->depth++; @@ -4075,4 +4088,11 @@ bool ts_query_cursor_next_capture( } } +void ts_query_cursor_set_max_start_depth( + TSQueryCursor *self, + uint32_t max_start_depth +) { + self->max_start_depth = max_start_depth; +} + #undef LOG