From bc0ef5f3736453d615e7ee1ed9e394d0631e8b4f Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Fri, 12 Mar 2021 14:38:02 -0800 Subject: [PATCH] Add negated field patterns to queries --- cli/src/tests/query_test.rs | 97 +++++++++++++++++++++++++++++ lib/src/query.c | 118 ++++++++++++++++++++++++++++++++++++ 2 files changed, 215 insertions(+) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 6a0e7075..a64e1e91 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -132,6 +132,18 @@ fn test_query_errors_on_invalid_syntax() { .join("\n") ); + // Need a field name after a negated field operator + assert_eq!( + Query::new(language, r#"(statement_block ! (if_statement))"#) + .unwrap_err() + .message, + [ + r#"(statement_block ! (if_statement))"#, + r#" ^"# + ] + .join("\n") + ); + // tree-sitter/tree-sitter/issues/968 assert_eq!( Query::new(get_language("c"), r#"(parameter_list [ ")" @foo)"#) @@ -201,6 +213,26 @@ fn test_query_errors_on_invalid_symbols() { message: "conditioning".to_string() } ); + assert_eq!( + Query::new(language, "(if_statement !alternativ)").unwrap_err(), + QueryError { + row: 0, + offset: 15, + column: 15, + kind: QueryErrorKind::Field, + message: "alternativ".to_string() + } + ); + assert_eq!( + Query::new(language, "(if_statement !alternatives)").unwrap_err(), + QueryError { + row: 0, + offset: 15, + column: 15, + kind: QueryErrorKind::Field, + message: "alternatives".to_string() + } + ); }); } @@ -894,6 +926,71 @@ fn test_query_matches_with_last_named_child() { }); } +#[test] +fn test_query_matches_with_negated_fields() { + allocations::record(|| { + let language = get_language("javascript"); + let query = Query::new( + language, + " + (import_specifier + !alias + name: (identifier) @import_name) + + (export_specifier + !alias + name: (identifier) @export_name) + + (export_statement + !decorator + !source + (_) @exported) + + ; This negated field list is an extension of a previous + ; negated 
field list. The order of the children and negated + ; fields doesn't matter. + (export_statement + !decorator + !source + (_) @exported_expr + !declaration) + + ; This negated field list is a prefix of a previous + ; negated field list. + (export_statement + !decorator + (_) @export_child .) + ", + ) + .unwrap(); + assert_query_matches( + language, + &query, + " + import {a as b, c} from 'p1'; + export {g, h as i} from 'p2'; + + @foo + export default 1; + + export var j = 1; + + export default k; + ", + &[ + (0, vec![("import_name", "c")]), + (1, vec![("export_name", "g")]), + (4, vec![("export_child", "'p2'")]), + (2, vec![("exported", "var j = 1;")]), + (4, vec![("export_child", "var j = 1;")]), + (2, vec![("exported", "k")]), + (3, vec![("exported_expr", "k")]), + (4, vec![("export_child", "k")]), + ], + ); + }); +} + #[test] fn test_query_matches_with_repeated_leaf_nodes() { allocations::record(|| { diff --git a/lib/src/query.c b/lib/src/query.c index f3342908..710a17b8 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -16,6 +16,7 @@ #define MAX_STEP_CAPTURE_COUNT 3 #define MAX_STATE_PREDECESSOR_COUNT 100 #define MAX_ANALYSIS_STATE_DEPTH 12 +#define MAX_NEGATED_FIELD_COUNT 8 /* * Stream - A sequence of unicode characters derived from a UTF8 string. 
@@ -75,6 +76,7 @@ typedef struct { uint16_t capture_ids[MAX_STEP_CAPTURE_COUNT]; uint16_t depth; uint16_t alternative_index; + uint16_t negated_field_list_id; bool contains_captures: 1; bool is_immediate: 1; bool is_last_child: 1; @@ -239,6 +241,7 @@ struct TSQuery { Array(TSQueryPredicateStep) predicate_steps; Array(QueryPattern) patterns; Array(StepOffset) step_offsets; + Array(TSFieldId) negated_fields; Array(char) string_buffer; const TSLanguage *language; uint16_t wildcard_root_pattern_count; @@ -480,6 +483,7 @@ static QueryStep query_step__new( .field = 0, .capture_ids = {NONE, NONE, NONE}, .alternative_index = NONE, + .negated_field_list_id = 0, .contains_captures = false, .is_last_child = false, .is_pass_through = false, @@ -1366,6 +1370,58 @@ static void ts_query__finalize_steps(TSQuery *self) { } } +static void ts_query__add_negated_fields( + TSQuery *self, + uint16_t step_index, + TSFieldId *field_ids, + uint16_t field_count +) { + QueryStep *step = &self->steps.contents[step_index]; + + // The negated field array stores a list of field lists, separated by zeros. + // Try to find the start index of an existing list that matches this new list. + bool failed_match = false; + unsigned match_count = 0; + unsigned start_i = 0; + for (unsigned i = 0; i < self->negated_fields.size; i++) { + TSFieldId existing_field_id = self->negated_fields.contents[i]; + + // At each zero value, terminate the match attempt. If we've exactly + // matched the new field list, then reuse this index. Otherwise, + // start over the matching process. + if (existing_field_id == 0) { + if (match_count == field_count) { + step->negated_field_list_id = start_i; + return; + } else { + start_i = i + 1; + match_count = 0; + failed_match = false; + } + } + + // If the existing list matches our new list so far, then advance + // to the next element of the new list. 
+ else if ( + match_count < field_count && + existing_field_id == field_ids[match_count] && + !failed_match + ) { + match_count++; + } + + // Otherwise, this existing list has failed to match. + else { + match_count = 0; + failed_match = true; + } + } + + step->negated_field_list_id = self->negated_fields.size; + array_extend(&self->negated_fields, field_count, field_ids); + array_push(&self->negated_fields, 0); +} + static TSQueryError ts_query__parse_string_literal( TSQuery *self, Stream *stream @@ -1716,7 +1772,39 @@ static TSQueryError ts_query__parse_pattern( // Parse the child patterns bool child_is_immediate = false; uint16_t last_child_step_index = 0; + uint16_t negated_field_count = 0; + TSFieldId negated_field_ids[MAX_NEGATED_FIELD_COUNT]; for (;;) { + // Parse a negated field assertion + if (stream->next == '!') { + stream_advance(stream); + stream_skip_whitespace(stream); + if (!stream_is_ident_start(stream)) return TSQueryErrorSyntax; + const char *field_name = stream->input; + stream_scan_identifier(stream); + uint32_t length = stream->input - field_name; + stream_skip_whitespace(stream); + + TSFieldId field_id = ts_language_field_id_for_name( + self->language, + field_name, + length + ); + if (!field_id) { + stream->input = field_name; + return TSQueryErrorField; + } + + // Append the field id in the order written (the list is not sorted). NOTE(review): ids beyond MAX_NEGATED_FIELD_COUNT are silently dropped. 
+ if (negated_field_count < MAX_NEGATED_FIELD_COUNT) { + negated_field_ids[negated_field_count] = field_id; + negated_field_count++; + } + + continue; + } + + // Parse a sibling anchor if (stream->next == '.') { child_is_immediate = true; stream_advance(stream); @@ -1737,6 +1825,16 @@ static TSQueryError ts_query__parse_pattern( } self->steps.contents[last_child_step_index].is_last_child = true; } + + if (negated_field_count) { + ts_query__add_negated_fields( + self, + starting_step_index, + negated_field_ids, + negated_field_count + ); + } + stream_advance(stream); break; } else if (e) { @@ -1945,10 +2043,13 @@ TSQuery *ts_query_new( .patterns = array_new(), .step_offsets = array_new(), .string_buffer = array_new(), + .negated_fields = array_new(), .wildcard_root_pattern_count = 0, .language = language, }; + array_push(&self->negated_fields, 0); + // Parse all of the S-expressions in the given string. Stream stream = stream_new(source, source_len); stream_skip_whitespace(&stream); @@ -2033,6 +2134,7 @@ void ts_query_delete(TSQuery *self) { array_delete(&self->patterns); array_delete(&self->step_offsets); array_delete(&self->string_buffer); + array_delete(&self->negated_fields); symbol_table_delete(&self->captures); symbol_table_delete(&self->predicate_values); ts_free(self); @@ -2700,6 +2802,22 @@ static inline bool ts_query_cursor__advance( } } + if (step->negated_field_list_id) { + TSFieldId *negated_field_ids = &self->query->negated_fields.contents[step->negated_field_list_id]; + for (;;) { + TSFieldId negated_field_id = *negated_field_ids; + if (negated_field_id) { + negated_field_ids++; + if (ts_node_child_by_field_id(node, negated_field_id).id) { + node_does_match = false; + break; + } + } else { + break; + } + } + } + // Remove states immediately if it is ever clear that they cannot match. if (!node_does_match) { if (!later_sibling_can_match) {