diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 5dbfea18..5499048e 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -87,6 +87,17 @@ fn test_query_errors_on_invalid_syntax() { .join("\n") )) ); + assert_eq!( + Query::new(language, r#"((identifier) [])"#), + Err(QueryError::Syntax( + 1, + [ + "((identifier) [])", // + " ^", + ] + .join("\n") + )) + ); assert_eq!( Query::new(language, r#"((identifier) (#a)"#), Err(QueryError::Syntax( @@ -367,8 +378,9 @@ fn test_query_matches_with_many_overlapping_results() { function: (identifier) @function) ((identifier) @constant (#match? @constant "[A-Z\\d_]+")) - "# - ).unwrap(); + "#, + ) + .unwrap(); let count = 80; @@ -388,8 +400,13 @@ fn test_query_matches_with_many_overlapping_results() { &[ (0, vec![("method", "foo")]), (1, vec![("function", "bar")]), - (2, vec![("constant", "BAZ")]) - ].iter().cloned().cycle().take(3 * count).collect::>(), + (2, vec![("constant", "BAZ")]), + ] + .iter() + .cloned() + .cycle() + .take(3 * count) + .collect::>(), ); }); } @@ -875,6 +892,122 @@ fn test_query_matches_with_repeated_internal_nodes() { }) } +#[test] +fn test_query_matches_with_simple_alternatives() { + allocations::record(|| { + let language = get_language("javascript"); + let query = Query::new( + language, + " + (pair + key: [(property_identifier) (string)] @key + value: [(function) @val1 (arrow_function) @val2]) + ", + ) + .unwrap(); + + assert_query_matches( + language, + &query, + " + a = { + b: c, + 'd': e => f, + g: { + h: function i() {}, + 'x': null, + j: _ => k + }, + 'l': function m() {}, + }; + ", + &[ + (0, vec![("key", "'d'"), ("val2", "e => f")]), + (0, vec![("key", "h"), ("val1", "function i() {}")]), + (0, vec![("key", "j"), ("val2", "_ => k")]), + (0, vec![("key", "'l'"), ("val1", "function m() {}")]), + ], + ); + }) +} + +#[test] +fn test_query_matches_with_alternatives_in_repetitions() { + allocations::record(|| { + let language = get_language("javascript"); + let query = Query::new( + language, + r#" + (array + [(identifier) (string)] @el + . + ( + "," + . + [(identifier) (string)] @el + )*) + "#, + ) + .unwrap(); + + assert_query_matches( + language, + &query, + " + a = [b, 'c', d, 1, e, 'f', 'g', h]; + ", + &[ + (0, vec![("el", "b"), ("el", "'c'"), ("el", "d")]), + ( + 0, + vec![("el", "e"), ("el", "'f'"), ("el", "'g'"), ("el", "h")], + ), + ], + ); + }) +} + +#[test] +fn test_query_matches_with_alternatives_at_root() { + allocations::record(|| { + let language = get_language("javascript"); + let query = Query::new( + language, + r#" + [ + "if" + "else" + "function" + "throw" + "return" + ] @keyword + "#, + ) + .unwrap(); + + assert_query_matches( + language, + &query, + " + function a(b, c, d) { + if (b) { + return c; + } else { + throw d; + } + } + ", + &[ + (0, vec![("keyword", "function")]), + (0, vec![("keyword", "if")]), + (0, vec![("keyword", "return")]), + (0, vec![("keyword", "else")]), + (0, vec![("keyword", "throw")]), + ], + ); + }) +} + #[test] fn test_query_matches_in_language_with_simple_aliases() { allocations::record(|| { diff --git a/lib/src/query.c b/lib/src/query.c index 89b5e4b5..19996066 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -8,6 +8,9 @@ #include "./unicode.h" #include +// #define LOG(...) fprintf(stderr, __VA_ARGS__) +#define LOG(...) + #define MAX_STATE_COUNT 256 #define MAX_CAPTURE_LIST_COUNT 32 #define MAX_STEP_CAPTURE_COUNT 3 @@ -32,9 +35,8 @@ typedef struct { * wildcard symbol, '_'. * - `field` - The field name to match. A zero value means that a field name * was not specified. - * - `capture_id` - An integer representing the name of the capture associated - * with this node in the pattern. A `NONE` value means this node is not - * captured in this pattern. + * - `capture_ids` - An array of integers representing the names of captures + * associated with this node in the pattern, terminated by a `NONE` value. * - `depth` - The depth where this node occurs in the pattern. The root node * of the pattern has depth zero. * - `alternative_index` - The index of a different query step that serves as @@ -49,8 +51,9 @@ typedef struct { bool contains_captures: 1; bool is_pattern_start: 1; bool is_immediate: 1; - bool is_last: 1; - bool is_placeholder: 1; + bool is_last_child: 1; + bool is_pass_through: 1; + bool is_dead_end: 1; bool alternative_is_immediate: 1; } QueryStep; @@ -177,9 +180,6 @@ static const uint16_t NONE = UINT16_MAX; static const TSSymbol WILDCARD_SYMBOL = 0; static const TSSymbol NAMED_WILDCARD_SYMBOL = UINT16_MAX - 1; -// #define LOG(...) fprintf(stderr, __VA_ARGS__) -#define LOG(...) - /********** * Stream **********/ @@ -447,9 +447,10 @@ static QueryStep query_step__new( .capture_ids = {NONE, NONE, NONE}, .alternative_index = NONE, .contains_captures = false, - .is_last = false, + .is_last_child = false, .is_pattern_start = false, - .is_placeholder = false, + .is_pass_through = false, + .is_dead_end = false, .is_immediate = is_immediate, .alternative_is_immediate = false, }; @@ -714,15 +715,60 @@ static TSQueryError ts_query__parse_pattern( uint32_t *capture_count, bool is_immediate ) { - uint16_t starting_step_index = self->steps.size; + uint32_t starting_step_index = self->steps.size; if (stream->next == 0) return TSQueryErrorSyntax; - // Finish the parent S-expression - if (stream->next == ')') { + // Finish the parent S-expression. + if (stream->next == ')' || stream->next == ']') { return PARENT_DONE; } + // An open bracket is the start of an alternation. + else if (stream->next == '[') { + stream_advance(stream); + stream_skip_whitespace(stream); + + // Parse each branch, and add a placeholder step in between the branches. + Array(uint32_t) branch_step_indices = array_new(); + for (;;) { + uint32_t start_index = self->steps.size; + TSQueryError e = ts_query__parse_pattern( + self, + stream, + depth, + capture_count, + is_immediate + ); + + if (e == PARENT_DONE && stream->next == ']' && branch_step_indices.size > 0) { + stream_advance(stream); + break; + } else if (e) { + array_delete(&branch_step_indices); + return e; + } + + array_push(&branch_step_indices, start_index); + array_push(&self->steps, query_step__new(0, depth, false)); + } + array_pop(&self->steps); + + // For all of the branches except for the last one, add the subsequent branch as an + // alternative, and link the end of the branch to the current end of the steps. + for (unsigned i = 0; i < branch_step_indices.size - 1; i++) { + uint32_t step_index = branch_step_indices.contents[i]; + uint32_t next_step_index = branch_step_indices.contents[i + 1]; + QueryStep *start_step = &self->steps.contents[step_index]; + QueryStep *end_step = &self->steps.contents[next_step_index - 1]; + start_step->alternative_index = next_step_index; + end_step->alternative_index = self->steps.size; + end_step->is_dead_end = true; + } + + array_delete(&branch_step_indices); + } + // An open parenthesis can be the start of three possible constructs: // * A grouped sequence // * A predicate @@ -732,7 +778,7 @@ static TSQueryError ts_query__parse_pattern( stream_skip_whitespace(stream); // If this parenthesis is followed by a node, then it represents a grouped sequence. - if (stream->next == '(' || stream->next == '"') { + if (stream->next == '(' || stream->next == '"' || stream->next == '[') { bool child_is_immediate = false; for (;;) { if (stream->next == '.') { @@ -747,7 +793,7 @@ static TSQueryError ts_query__parse_pattern( capture_count, child_is_immediate ); - if (e == PARENT_DONE) { + if (e == PARENT_DONE && stream->next == ')') { stream_advance(stream); break; } else if (e) { @@ -828,9 +874,9 @@ static TSQueryError ts_query__parse_pattern( capture_count, child_is_immediate ); - if (e == PARENT_DONE) { + if (e == PARENT_DONE && stream->next == ')') { if (child_is_immediate) { - self->steps.contents[child_start_step_index].is_last = true; + self->steps.contents[child_start_step_index].is_last_child = true; } stream_advance(stream); break; @@ -939,42 +985,54 @@ static TSQueryError ts_query__parse_pattern( for (;;) { QueryStep *step = &self->steps.contents[starting_step_index]; + // Parse the one-or-more operator. if (stream->next == '+') { stream_advance(stream); stream_skip_whitespace(stream); + QueryStep repeat_step = query_step__new(WILDCARD_SYMBOL, depth, false); repeat_step.alternative_index = starting_step_index; - repeat_step.is_placeholder = true; + repeat_step.is_pass_through = true; repeat_step.alternative_is_immediate = true; array_push(&self->steps, repeat_step); } - else if (stream->next == '?') { - stream_advance(stream); - stream_skip_whitespace(stream); - step->alternative_index = self->steps.size; - } - + // Parse the zero-or-more repetition operator. else if (stream->next == '*') { stream_advance(stream); stream_skip_whitespace(stream); + QueryStep repeat_step = query_step__new(WILDCARD_SYMBOL, depth, false); repeat_step.alternative_index = starting_step_index; - repeat_step.is_placeholder = true; + repeat_step.is_pass_through = true; repeat_step.alternative_is_immediate = true; array_push(&self->steps, repeat_step); + + while (step->alternative_index != NONE) { + step = &self->steps.contents[step->alternative_index]; + } + step->alternative_index = self->steps.size; + } + + // Parse the optional operator. + else if (stream->next == '?') { + stream_advance(stream); + stream_skip_whitespace(stream); + + while (step->alternative_index != NONE) { + step = &self->steps.contents[step->alternative_index]; + } step->alternative_index = self->steps.size; } // Parse an '@'-prefixed capture pattern else if (stream->next == '@') { stream_advance(stream); - - // Parse the capture name if (!stream_is_ident_start(stream)) return TSQueryErrorSyntax; const char *capture_name = stream->input; stream_scan_identifier(stream); uint32_t length = stream->input - capture_name; + stream_skip_whitespace(stream); // Add the capture id to the first step of the pattern uint16_t capture_id = symbol_table_insert_name( @@ -982,10 +1040,22 @@ static TSQueryError ts_query__parse_pattern( capture_name, length ); - query_step__add_capture(step, capture_id); - (*capture_count)++; - stream_skip_whitespace(stream); + for (;;) { + query_step__add_capture(step, capture_id); + if ( + step->alternative_index != NONE && + step->alternative_index > starting_step_index && + step->alternative_index < self->steps.size + ) { + starting_step_index = step->alternative_index; + step = &self->steps.contents[starting_step_index]; + } else { + break; + } + } + + (*capture_count)++; } // No more suffix modifiers @@ -1062,6 +1132,7 @@ TSQuery *ts_query_new( // If any pattern could not be parsed, then report the error information // and terminate. if (*error_type) { + if (*error_type == PARENT_DONE) *error_type = TSQueryErrorSyntax; *error_offset = stream.input - source; ts_query_delete(self); return NULL; @@ -1086,6 +1157,9 @@ TSQuery *ts_query_new( if (step->symbol == WILDCARD_SYMBOL) { self->wildcard_root_pattern_count++; } + + // If there are alternatives or options at the root of the pattern, + // then add multiple entries to the pattern map. if (step->alternative_index != NONE) { start_step_index = step->alternative_index; } else { @@ -1583,7 +1657,7 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { if ((step->is_immediate && is_named) || state->seeking_immediate_match) { later_sibling_can_match = false; } - if (step->is_last && can_have_later_siblings) { + if (step->is_last_child && can_have_later_siblings) { node_does_match = false; } if (step->field) { @@ -1705,8 +1779,14 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { QueryState *state = &self->states.contents[j]; QueryStep *next_step = &self->query->steps.contents[state->step_index]; if (next_step->alternative_index != NONE) { + if (next_step->is_dead_end) { + state->step_index = next_step->alternative_index; + j--; + continue; + } + QueryState *copy = ts_query__cursor_copy_state(self, state); - if (next_step->is_placeholder) { + if (next_step->is_pass_through) { state->step_index++; j--; } @@ -1718,10 +1798,11 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { copy->seeking_immediate_match = true; } LOG( - " split state for branch. pattern:%u, step:%u, step:%u\n", + " split state for branch. pattern:%u, step:%u, step:%u, immediate:%d\n", copy->pattern_index, state->step_index, - copy->step_index + copy->step_index, + copy->seeking_immediate_match ); } }