diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 4293c568..d4ebd884 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -2,8 +2,8 @@ use super::helpers::allocations; use super::helpers::fixtures::get_language; use std::fmt::Write; use tree_sitter::{ - Node, Parser, Query, QueryCapture, QueryCursor, QueryError, QueryMatch, QueryPredicate, - QueryPredicateArg, QueryProperty, + Language, Node, Parser, Query, QueryCapture, QueryCursor, QueryError, QueryMatch, + QueryPredicate, QueryPredicateArg, QueryProperty, }; #[test] @@ -163,19 +163,13 @@ fn test_query_matches_with_simple_pattern() { ) .unwrap(); - let source = "function one() { two(); function three() {} }"; - let mut parser = Parser::new(); - parser.set_language(language).unwrap(); - let tree = parser.parse(source, None).unwrap(); - - let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); - - assert_eq!( - collect_matches(matches, &query, source), + assert_query_matches( + language, + &query, + "function one() { two(); function three() {} }", &[ (0, vec![("fn-name", "one")]), - (0, vec![("fn-name", "three")]) + (0, vec![("fn-name", "three")]), ], ); }); @@ -195,7 +189,10 @@ fn test_query_matches_with_multiple_on_same_root() { ) .unwrap(); - let source = " + assert_query_matches( + language, + &query, + " class Person { // the constructor constructor(name) { this.name = name; } @@ -203,30 +200,21 @@ fn test_query_matches_with_multiple_on_same_root() { // the getter getFullName() { return this.name; } } - "; - - let mut parser = Parser::new(); - parser.set_language(language).unwrap(); - let tree = parser.parse(source, None).unwrap(); - let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); - - assert_eq!( - collect_matches(matches, &query, source), + ", &[ ( 0, vec![ ("the-class-name", "Person"), - ("the-method-name", "constructor") - ] + ("the-method-name", "constructor"), + ], ), ( 0, vec![ ("the-class-name", "Person"), - ("the-method-name", "getFullName") - ] + ("the-method-name", "getFullName"), + ], ), ], ); @@ -246,20 +234,14 @@ fn test_query_matches_with_multiple_patterns_different_roots() { ) .unwrap(); - let source = " + assert_query_matches( + language, + &query, + " function f1() { f2(f3()); } - "; - - let mut parser = Parser::new(); - parser.set_language(language).unwrap(); - let tree = parser.parse(source, None).unwrap(); - let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); - - assert_eq!( - collect_matches(matches, &query, source), + ", &[ (0, vec![("fn-def", "f1")]), (1, vec![("fn-ref", "f2")]), @@ -287,21 +269,15 @@ fn test_query_matches_with_multiple_patterns_same_root() { ) .unwrap(); - let source = " + assert_query_matches( + language, + &query, + " a = { b: () => { return c; }, d: function() { return d; } }; - "; - - let mut parser = Parser::new(); - parser.set_language(language).unwrap(); - let tree = parser.parse(source, None).unwrap(); - let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); - - assert_eq!( - collect_matches(matches, &query, source), + ", &[ (1, vec![("method-def", "b")]), (0, vec![("method-def", "d")]), @@ -325,20 +301,14 @@ fn test_query_matches_with_nesting_and_no_fields() { ) .unwrap(); - let source = " + assert_query_matches( + language, + &query, + " [[a]]; [[c, d], [e, f, g, h]]; [[h], [i]]; - "; - - let mut parser = Parser::new(); - parser.set_language(language).unwrap(); - let tree = parser.parse(source, None).unwrap(); - let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); - - assert_eq!( - collect_matches(matches, &query, source), + ", &[ (0, vec![("x1", "c"), ("x2", "d")]), (0, vec![("x1", "e"), ("x2", "f")]), @@ -358,17 +328,11 @@ fn test_query_matches_with_many() { let language = get_language("javascript"); let query = Query::new(language, "(array (identifier) @element)").unwrap(); - let source = "[hello];\n".repeat(50); - - let mut parser = Parser::new(); - parser.set_language(language).unwrap(); - let tree = parser.parse(&source, None).unwrap(); - let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(&source)); - - assert_eq!( - collect_matches(matches, &query, source.as_str()), - vec![(0, vec![("element", "hello")]); 50], + assert_query_matches( + language, + &query, + &"[hello];\n".repeat(50), + &vec![(0, vec![("element", "hello")]); 50], ); }); } @@ -385,20 +349,11 @@ fn test_query_matches_capturing_error_nodes() { ) .unwrap(); - let source = "function a(b,, c, d :e:) {}"; - - let mut parser = Parser::new(); - parser.set_language(language).unwrap(); - let tree = parser.parse(source, None).unwrap(); - let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); - - assert_eq!( - collect_matches(matches, &query, source), - &[( - 0, - vec![("the-error", ":e:"), ("the-error-identifier", "e"),] - ),] + assert_query_matches( + language, + &query, + "function a(b,, c, d :e:) {}", + &[(0, vec![("the-error", ":e:"), ("the-error-identifier", "e")])], ); }); } @@ -439,10 +394,6 @@ fn test_query_matches_with_named_wildcard() { fn test_query_matches_with_wildcard_at_the_root() { allocations::record(|| { let language = get_language("javascript"); - let mut cursor = QueryCursor::new(); - let mut parser = Parser::new(); - parser.set_language(language).unwrap(); - let query = Query::new( language, " @@ -455,13 +406,11 @@ fn test_query_matches_with_wildcard_at_the_root() { ) .unwrap(); - let source = "/* one */ var x; /* two */ function y() {} /* three */ class Z {}"; - - let tree = parser.parse(source, None).unwrap(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); - assert_eq!( - collect_matches(matches, &query, source), - &[(0, vec![("doc", "/* two */"), ("name", "y")]),] + assert_query_matches( + language, + &query, + "/* one */ var x; /* two */ function y() {} /* three */ class Z {}", + &[(0, vec![("doc", "/* two */"), ("name", "y")])], ); let query = Query::new( @@ -475,17 +424,15 @@ fn test_query_matches_with_wildcard_at_the_root() { ) .unwrap(); - let source = "['hi', x(true), {y: false}]"; - - let tree = parser.parse(source, None).unwrap(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); - assert_eq!( - collect_matches(matches, &query, source), + assert_query_matches( + language, + &query, + "['hi', x(true), {y: false}]", &[ (0, vec![("a", "'hi'")]), (2, vec![("c", "true")]), (3, vec![("d", "false")]), - ] + ], ); }); } @@ -519,16 +466,10 @@ fn test_query_matches_with_immediate_siblings() { ) .unwrap(); - let source = "import a.b.c.d; return [w, [1, y], z]"; - - let mut parser = Parser::new(); - parser.set_language(language).unwrap(); - let tree = parser.parse(source, None).unwrap(); - let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); - - assert_eq!( - collect_matches(matches, &query, source), + assert_query_matches( + language, + &query, + "import a.b.c.d; return [w, [1, y], z]", &[ (0, vec![("parent", "a"), ("child", "b")]), (0, vec![("parent", "b"), ("child", "c")]), @@ -536,7 +477,7 @@ fn test_query_matches_with_immediate_siblings() { (0, vec![("parent", "c"), ("child", "d")]), (2, vec![("first-element", "w")]), (2, vec![("first-element", "1")]), - ] + ], ); }); } @@ -564,7 +505,10 @@ fn test_query_matches_with_repeated_leaf_nodes() { ) .unwrap(); - let source = " + assert_query_matches( + language, + &query, + " // one // two a(); @@ -582,16 +526,7 @@ fn test_query_matches_with_repeated_leaf_nodes() { // eight function d() {} } - "; - - let mut parser = Parser::new(); - parser.set_language(language).unwrap(); - let tree = parser.parse(source, None).unwrap(); - let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); - - assert_eq!( - collect_matches(matches, &query, source), + ", &[ ( 0, @@ -599,11 +534,31 @@ fn test_query_matches_with_repeated_leaf_nodes() { ("doc", "// four"), ("doc", "// five"), ("doc", "// six"), - ("name", "B") - ] + ("name", "B"), + ], ), (1, vec![("doc", "// eight"), ("name", "d")]), - ] + ], + ); + }); +} + +#[test] +fn test_query_matches_with_optional_nodes_inside_of_repetitions() { + allocations::record(|| { + let language = get_language("javascript"); + let query = Query::new(language, r#"(array (","? (number) @num)+)"#).unwrap(); + + assert_query_matches( + language, + &query, + r#" + var a = [1, 2, 3, 4] + "#, + &[( + 0, + vec![("num", "1"), ("num", "2"), ("num", "3"), ("num", "4")], + )], ); }); } @@ -625,43 +580,37 @@ fn test_query_matches_with_leading_optional_repeated_leaf_nodes() { ) .unwrap(); - let source = " - function a() { - // one - var b; + assert_query_matches( + language, + &query, + " + function a() { + // one + var b; - function c() {} + function c() {} - // two - // three - var d; + // two + // three + var d; - // four - // five - function e() { + // four + // five + function e() { + } } - } - // six - "; - - let mut parser = Parser::new(); - parser.set_language(language).unwrap(); - let tree = parser.parse(source, None).unwrap(); - let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); - - assert_eq!( - collect_matches(matches, &query, source), + // six + ", &[ (0, vec![("name", "a")]), (0, vec![("name", "c")]), ( 0, - vec![("doc", "// four"), ("doc", "// five"), ("name", "e")] + vec![("doc", "// four"), ("doc", "// five"), ("name", "e")], ), - ] + ], ); }); } @@ -682,37 +631,21 @@ fn test_query_matches_with_optional_nodes() { ) .unwrap(); - let mut parser = Parser::new(); - parser.set_language(language).unwrap(); + assert_query_matches(language, &query, "class A {}", &[(0, vec![("class", "A")])]); - let source = " - class A {} - "; - let tree = parser.parse(source, None).unwrap(); - let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); - - assert_eq!( - collect_matches(matches, &query, source), - &[(0, vec![("class", "A")]),] - ); - - let source = " + assert_query_matches( + language, + &query, + " class A {} class B extends C {} class D extends (E.F) {} - "; - let tree = parser.parse(source, None).unwrap(); - let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); - - assert_eq!( - collect_matches(matches, &query, source), + ", &[ (0, vec![("class", "A")]), (0, vec![("class", "B"), ("superclass", "C")]), (0, vec![("class", "D")]), - ] + ], ); }); } @@ -721,10 +654,6 @@ fn test_query_matches_with_optional_nodes() { fn test_query_matches_with_repeated_internal_nodes() { allocations::record(|| { let language = get_language("javascript"); - let mut parser = Parser::new(); - parser.set_language(language).unwrap(); - let mut cursor = QueryCursor::new(); - let query = Query::new( language, " @@ -735,18 +664,18 @@ fn test_query_matches_with_repeated_internal_nodes() { ", ) .unwrap(); - let source = " + + assert_query_matches( + language, + &query, + " class A { @c @d e() {} } - "; - let tree = parser.parse(source, None).unwrap(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); - assert_eq!( - collect_matches(matches, &query, source), - &[(0, vec![("deco", "c"), ("deco", "d"), ("name", "e")]),] + ", + &[(0, vec![("deco", "c"), ("deco", "d"), ("name", "e")])], ); }) } @@ -760,20 +689,16 @@ fn test_query_matches_in_language_with_simple_aliases() { // tag names, script tag names, and style tag names. All of // these tokens are aliased to `tag_name`. let query = Query::new(language, "(tag_name) @tag").unwrap(); - let source = " + + assert_query_matches( + language, + &query, + "
-
"; - - let mut parser = Parser::new(); - parser.set_language(language).unwrap(); - let tree = parser.parse(&source, None).unwrap(); - let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(&source)); - - assert_eq!( - collect_matches(matches, &query, source), + + ", &[ (0, vec![("tag", "div")]), (0, vec![("tag", "script")]), @@ -789,6 +714,8 @@ fn test_query_matches_in_language_with_simple_aliases() { #[test] fn test_query_matches_with_different_tokens_with_the_same_string_value() { allocations::record(|| { + // In Rust, there are two '<' tokens: one for the binary operator, + // and one with higher precedence for generics. let language = get_language("rust"); let query = Query::new( language, @@ -799,24 +726,16 @@ fn test_query_matches_with_different_tokens_with_the_same_string_value() { ) .unwrap(); - // In Rust, there are two '<' tokens: one for the binary operator, - // and one with higher precedence for generics. - let source = "const A: B = d < e || f > g;"; - - let mut parser = Parser::new(); - parser.set_language(language).unwrap(); - let tree = parser.parse(&source, None).unwrap(); - let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); - - assert_eq!( - collect_matches(matches, &query, source), + assert_query_matches( + language, + &query, + "const A: B = d < e || f > g;", &[ (0, vec![("less", "<")]), (1, vec![("greater", ">")]), (0, vec![("less", "<")]), (1, vec![("greater", ">")]), - ] + ], ); }); } @@ -866,20 +785,14 @@ fn test_query_matches_with_anonymous_tokens() { ) .unwrap(); - let source = "foo(a && b);"; - - let mut parser = Parser::new(); - parser.set_language(language).unwrap(); - let tree = parser.parse(&source, None).unwrap(); - let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); - - assert_eq!( - collect_matches(matches, &query, source), + assert_query_matches( + language, + &query, + "foo(a && b);", &[ (1, vec![("operator", "&&")]), (0, vec![("punctuation", ";")]), - ] + ], ); }); } @@ -1772,6 +1685,20 @@ fn test_query_disable_pattern() { }); } +fn assert_query_matches( + language: Language, + query: &Query, + source: &str, + expected: &[(usize, Vec<(&str, &str)>)], +) { + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(source, None).unwrap(); + let mut cursor = QueryCursor::new(); + let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); + assert_eq!(collect_matches(matches, &query, source), expected); +} + fn collect_matches<'a>( matches: impl Iterator>, query: &'a Query, diff --git a/lib/src/query.c b/lib/src/query.c index f33b646d..29c9e837 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -35,7 +35,8 @@ typedef struct { * captured in this pattern. * - `depth` - The depth where this node occurs in the pattern. The root node * of the pattern has depth zero. - * - `alternative_index` - TODO doc + * - `alternative_index` - The index of a different query step that serves as + * an alternative to this step. */ typedef struct { TSSymbol symbol; @@ -87,7 +88,25 @@ typedef struct { * QueryState - The state of an in-progress match of a particular pattern * in a query. While executing, a `TSQueryCursor` must keep track of a number * of possible in-progress matches. Each of those possible matches is - * represented as one of these states. + * represented as one of these states. Fields: + * - `id` - A numeric id that is exposed to the public API. This allows the + * caller to remove a given match, preventing any more of its captures + * from being returned. + * - `start_depth` - The depth in the tree where the first step of the state's + * pattern was matched. + * - `pattern_index` - The pattern that the state is matching. + * - `consumed_capture_count` - The number of captures from this match that + * have already been returned. + * - `capture_list_id` - A numeric id that can be used to retrieve the state's + * list of captures from the `CaptureListPool`. + * - `seeking_immediate_match` - A flag that indicates that the state's next + * step must be matched by the very next sibling. This is used when + * processing repetitions. + * - `skipped_trailing_optional` - A flag that indicates that there is an + * optional node at the end of this state's pattern, and this state did + * *not* match that node. In order to obey the 'longest-match' rule, this + * match should not be returned until it is clear that there can be no + * longer match. */ typedef struct { uint32_t id; @@ -689,7 +708,7 @@ static TSQueryError ts_query__parse_pattern( stream_advance(stream); stream_skip_whitespace(stream); - // Parse a nested list, which represents a pattern followed by + // At the top-level, a nested list represents one root pattern followed by // zero-or-more predicates. if (stream->next == '(' && depth == 0) { TSQueryError e = ts_query__parse_pattern(self, stream, 0, capture_count, is_immediate); @@ -709,65 +728,94 @@ static TSQueryError ts_query__parse_pattern( } } - TSSymbol symbol; + // When nested inside of a larger pattern, a nested list just represents + // multiple sibling nodes which are grouped, possibly so that a postfix + // operator can be applied to the group. + else if (depth > 0 && (stream->next == '(' || stream->next == '"' )) { + bool child_is_immediate = false; + for (;;) { + if (stream->next == '.') { + child_is_immediate = true; + stream_advance(stream); + stream_skip_whitespace(stream); + } + TSQueryError e = ts_query__parse_pattern( + self, + stream, + depth, + capture_count, + child_is_immediate + ); + if (e == PARENT_DONE) { + stream_advance(stream); + break; + } else if (e) { + return e; + } - // Parse the wildcard symbol - if (stream->next == '*') { - symbol = depth > 0 ? NAMED_WILDCARD_SYMBOL : WILDCARD_SYMBOL; - stream_advance(stream); - } - - // Parse a normal node name - else if (stream_is_ident_start(stream)) { - const char *node_name = stream->input; - stream_scan_identifier(stream); - uint32_t length = stream->input - node_name; - symbol = ts_language_symbol_for_name( - self->language, - node_name, - length, - true - ); - if (!symbol) { - stream_reset(stream, node_name); - return TSQueryErrorNodeType; + child_is_immediate = false; } } else { - return TSQueryErrorSyntax; - } + TSSymbol symbol; - // Add a step for the node. - array_push(&self->steps, query_step__new(symbol, depth, is_immediate)); - - // Parse the child patterns - stream_skip_whitespace(stream); - bool child_is_immediate = false; - uint16_t child_start_step_index = self->steps.size; - for (;;) { - if (stream->next == '.') { - child_is_immediate = true; + // Parse the wildcard symbol + if (stream->next == '*') { + symbol = depth > 0 ? NAMED_WILDCARD_SYMBOL : WILDCARD_SYMBOL; stream_advance(stream); - stream_skip_whitespace(stream); } - TSQueryError e = ts_query__parse_pattern( - self, - stream, - depth + 1, - capture_count, - child_is_immediate - ); - if (e == PARENT_DONE) { - if (child_is_immediate) { - self->steps.contents[child_start_step_index].is_last = true; + // Parse a normal node name + else if (stream_is_ident_start(stream)) { + const char *node_name = stream->input; + stream_scan_identifier(stream); + uint32_t length = stream->input - node_name; + symbol = ts_language_symbol_for_name( + self->language, + node_name, + length, + true + ); + if (!symbol) { + stream_reset(stream, node_name); + return TSQueryErrorNodeType; } - stream_advance(stream); - break; - } else if (e) { - return e; + } else { + return TSQueryErrorSyntax; } - child_is_immediate = false; + // Add a step for the node. + array_push(&self->steps, query_step__new(symbol, depth, is_immediate)); + + // Parse the child patterns + stream_skip_whitespace(stream); + bool child_is_immediate = false; + uint16_t child_start_step_index = self->steps.size; + for (;;) { + if (stream->next == '.') { + child_is_immediate = true; + stream_advance(stream); + stream_skip_whitespace(stream); + } + + TSQueryError e = ts_query__parse_pattern( + self, + stream, + depth + 1, + capture_count, + child_is_immediate + ); + if (e == PARENT_DONE) { + if (child_is_immediate) { + self->steps.contents[child_start_step_index].is_last = true; + } + stream_advance(stream); + break; + } else if (e) { + return e; + } + + child_is_immediate = false; + } } } @@ -1577,6 +1625,7 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { // an interative process. unsigned start_index = state - self->states.contents; unsigned end_index = start_index + 1; + bool is_alternative = false; for (unsigned j = start_index; j < end_index; j++) { QueryState *state = &self->states.contents[j]; QueryStep *next_step = &self->query->steps.contents[state->step_index]; @@ -1600,14 +1649,19 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { copy->step_index ); } - } else if (next_step->depth == PATTERN_DONE_MARKER && j > start_index) { - state->skipped_trailing_optional = true; } + + if ( + (next_step->alternative_index != NONE || is_alternative) && + next_step->depth == PATTERN_DONE_MARKER + ) state->skipped_trailing_optional = true; + is_alternative = true; } } for (unsigned i = 0; i < self->states.size; i++) { QueryState *state = &self->states.contents[i]; + bool did_remove = false; // Enfore the longest-match criteria. When a query pattern contains optional or // repeated nodes, this is necesssary to avoid multiple redundant states, where @@ -1636,7 +1690,7 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { if (right_contains_left) { capture_list_pool_release(&self->capture_list_pool, state->capture_list_id); array_erase(&self->states, i); - i--; + did_remove = true; j--; break; } else if (left_contains_right) { @@ -1650,16 +1704,18 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { // If there the state is at the end of its pattern, remove it from the list // of in-progress states and add it to the list of finished states. - QueryStep *next_step = &self->query->steps.contents[state->step_index]; - if (next_step->depth == PATTERN_DONE_MARKER) { - if (state->skipped_trailing_optional) { - LOG(" defer finishing pattern %u\n", state->pattern_index); - } else { - LOG(" finish pattern %u\n", state->pattern_index); - state->id = self->next_state_id++; - array_push(&self->finished_states, *state); - array_erase(&self->states, i); - i--; + if (!did_remove) { + QueryStep *next_step = &self->query->steps.contents[state->step_index]; + if (next_step->depth == PATTERN_DONE_MARKER) { + if (state->skipped_trailing_optional) { + LOG(" defer finishing pattern %u\n", state->pattern_index); + } else { + LOG(" finish pattern %u\n", state->pattern_index); + state->id = self->next_state_id++; + array_push(&self->finished_states, *state); + array_erase(&self->states, state - self->states.contents); + i--; + } } } }