From 4301110c126b8fabe45a00b20ce965d4043910d8 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 20 Aug 2020 13:06:38 -0700 Subject: [PATCH] query: Indicate specific step that's impossible --- cli/src/tests/query_test.rs | 69 ++++++++++++----------- lib/binding_rust/bindings.rs | 7 +-- lib/binding_rust/lib.rs | 12 ++-- lib/include/tree_sitter/api.h | 6 +- lib/src/query.c | 100 ++++++++++++++++++++-------------- 5 files changed, 103 insertions(+), 91 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index a18c3a8b..1e4ea8cc 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -197,7 +197,11 @@ fn test_query_errors_on_impossible_patterns() { ), Err(QueryError::Pattern( 1, - "(binary_expression left: (identifier) left: (identifier))\n^".to_string(), + [ + "(binary_expression left: (identifier) left: (identifier))", + " ^" + ] + .join("\n"), )) ); @@ -210,7 +214,11 @@ fn test_query_errors_on_impossible_patterns() { Query::new(js_lang, "(function_declaration name: (statement_block))"), Err(QueryError::Pattern( 1, - "(function_declaration name: (statement_block))\n^".to_string(), + [ + "(function_declaration name: (statement_block))", + " ^", + ] + .join("\n") )) ); @@ -219,7 +227,11 @@ fn test_query_errors_on_impossible_patterns() { Query::new(rb_lang, "(call receiver:(binary))"), Err(QueryError::Pattern( 1, - "(call receiver:(binary))\n^".to_string(), + [ + "(call receiver:(binary))", // + " ^", + ] + .join("\n") )) ); }); @@ -2307,55 +2319,52 @@ fn test_query_alternative_predicate_prefix() { } #[test] -fn test_query_is_definite() { +fn test_query_step_is_definite() { struct Row { language: Language, pattern: &'static str, - results_by_symbol: &'static [(&'static str, bool)], + results_by_substring: &'static [(&'static str, bool)], } let rows = &[ Row { language: get_language("python"), pattern: r#"(expression_statement (string))"#, - results_by_symbol: &[("expression_statement", false), ("string", false)], + results_by_substring: &[("expression_statement", false), ("string", false)], }, Row { language: get_language("javascript"), pattern: r#"(expression_statement (string))"#, - results_by_symbol: &[ - ("expression_statement", false), - ("string", false), // string - ], + results_by_substring: &[("expression_statement", false), ("string", false)], }, Row { language: get_language("javascript"), pattern: r#"(object "{" "}")"#, - results_by_symbol: &[("object", false), ("{", true), ("}", true)], + results_by_substring: &[("object", false), ("{", true), ("}", true)], }, Row { language: get_language("javascript"), pattern: r#"(pair (property_identifier) ":")"#, - results_by_symbol: &[("pair", false), ("property_identifier", false), (":", true)], + results_by_substring: &[("pair", false), ("property_identifier", false), (":", true)], }, Row { language: get_language("javascript"), pattern: r#"(object "{" (_) "}")"#, - results_by_symbol: &[("object", false), ("{", false), ("", false), ("}", true)], + results_by_substring: &[("object", false), ("{", false), ("", false), ("}", true)], }, Row { language: get_language("javascript"), pattern: r#"(binary_expression left: (identifier) right: (_))"#, - results_by_symbol: &[ + results_by_substring: &[ ("binary_expression", false), - ("identifier", false), - ("", true), + ("(identifier)", false), + ("(_)", true), ], }, Row { language: get_language("javascript"), pattern: r#"(function_declaration name: (identifier) body: (statement_block))"#, - results_by_symbol: &[ + results_by_substring: &[ ("function_declaration", false), ("identifier", true), ("statement_block", true), @@ -2367,7 +2376,7 @@ fn test_query_is_definite() { (function_declaration name: (identifier) body: (statement_block "{" (expression_statement) "}"))"#, - results_by_symbol: &[ + results_by_substring: &[ ("function_declaration", false), ("identifier", false), ("statement_block", false), @@ -2383,7 +2392,7 @@ fn test_query_is_definite() { value: (constant) "end") "#, - results_by_symbol: &[ + results_by_substring: &[ ("singleton_class", false), ("constant", false), ("end", true), @@ -2397,7 +2406,7 @@ fn test_query_is_definite() { property: (property_identifier) @template-tag) arguments: (template_string)) @template-call "#, - results_by_symbol: &[("property_identifier", false), ("template_string", false)], + results_by_substring: &[("property_identifier", false), ("template_string", false)], }, Row { language: get_language("javascript"), @@ -2408,7 +2417,7 @@ fn test_query_is_definite() { property: (property_identifier) @prop) "[") "#, - results_by_symbol: &[ + results_by_substring: &[ ("identifier", false), ("property_identifier", true), ("[", true), @@ -2424,7 +2433,7 @@ fn test_query_is_definite() { "[" (#match? @prop "foo")) "#, - results_by_symbol: &[ + results_by_substring: &[ ("identifier", false), ("property_identifier", false), ("[", true), @@ -2435,23 +2444,17 @@ fn test_query_is_definite() { allocations::record(|| { for row in rows.iter() { let query = Query::new(row.language, row.pattern).unwrap(); - for (symbol_name, is_definite) in row.results_by_symbol { - let mut symbol = 0; - if !symbol_name.is_empty() { - symbol = row.language.id_for_node_kind(symbol_name, true); - if symbol == 0 { - symbol = row.language.id_for_node_kind(symbol_name, false); - } - } + for (substring, is_definite) in row.results_by_substring { + let offset = row.pattern.find(substring).unwrap(); assert_eq!( - query.pattern_is_definite(0, symbol, 0), + query.step_is_definite(offset), *is_definite, - "Pattern: {:?}, symbol: {}, expected is_definite to be {}", + "Pattern: {:?}, substring: {:?}, expected is_definite to be {}", row.pattern .split_ascii_whitespace() .collect::>() .join(" "), - symbol_name, + substring, is_definite, ) } diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index b5ff7a9e..81cc6f9a 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -651,12 +651,7 @@ extern "C" { ) -> *const TSQueryPredicateStep; } extern "C" { - pub fn ts_query_pattern_is_definite( - self_: *const TSQuery, - pattern_index: u32, - symbol: TSSymbol, - step_index: u32, - ) -> bool; + pub fn ts_query_step_is_definite(self_: *const TSQuery, byte_offset: u32) -> bool; } extern "C" { #[doc = " Get the name and length of one of the query\'s captures, or one of the"] diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index c601aecc..10cd9fc2 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -1467,12 +1467,12 @@ impl Query { unsafe { ffi::ts_query_disable_pattern(self.ptr.as_ptr(), index as u32) } } - /// Check if a pattern will definitely match after a certain number of steps - /// have matched. - pub fn pattern_is_definite(&self, pattern_index: usize, symbol: u16, step_index: usize) -> bool { - unsafe { - ffi::ts_query_pattern_is_definite(self.ptr.as_ptr(), pattern_index as u32, symbol, step_index as u32) - } + /// Check if a given step in a query is 'definite'. + /// + /// A query step is 'definite' if its parent pattern will be guaranteed to match + /// successfully once it reaches the step. + pub fn step_is_definite(&self, byte_offset: usize) -> bool { + unsafe { ffi::ts_query_step_is_definite(self.ptr.as_ptr(), byte_offset as u32) } } fn parse_property( diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 850cd31e..1e60e4b5 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -719,11 +719,9 @@ const TSQueryPredicateStep *ts_query_predicates_for_pattern( uint32_t *length ); -bool ts_query_pattern_is_definite( +bool ts_query_step_is_definite( const TSQuery *self, - uint32_t pattern_index, - TSSymbol symbol, - uint32_t step_index + uint32_t byte_offset ); /** diff --git a/lib/src/query.c b/lib/src/query.c index a156beb9..5a2bb2fb 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -22,6 +22,7 @@ */ typedef struct { const char *input; + const char *start; const char *end; int32_t next; uint8_t next_size; @@ -96,6 +97,11 @@ typedef struct { uint32_t start_byte; } QueryPattern; +typedef struct { + uint32_t byte_offset; + uint16_t step_index; +} StepOffset; + /* * QueryState - The state of an in-progress match of a particular pattern * in a query. While executing, a `TSQueryCursor` must keep track of a number @@ -202,6 +208,7 @@ struct TSQuery { Array(PatternEntry) pattern_map; Array(TSQueryPredicateStep) predicate_steps; Array(QueryPattern) patterns; + Array(StepOffset) step_offsets; const TSLanguage *language; uint16_t wildcard_root_pattern_count; TSSymbol *symbol_map; @@ -268,21 +275,22 @@ static Stream stream_new(const char *string, uint32_t length) { Stream self = { .next = 0, .input = string, + .start = string, .end = string + length, }; stream_advance(&self); return self; } -static void stream_skip_whitespace(Stream *stream) { +static void stream_skip_whitespace(Stream *self) { for (;;) { - if (iswspace(stream->next)) { - stream_advance(stream); - } else if (stream->next == ';') { + if (iswspace(self->next)) { + stream_advance(self); + } else if (self->next == ';') { // skip over comments - stream_advance(stream); - while (stream->next && stream->next != '\n') { - if (!stream_advance(stream)) break; + stream_advance(self); + while (self->next && self->next != '\n') { + if (!stream_advance(self)) break; } } else { break; @@ -290,8 +298,8 @@ static void stream_skip_whitespace(Stream *stream) { } } -static bool stream_is_ident_start(Stream *stream) { - return iswalnum(stream->next) || stream->next == '_' || stream->next == '-'; +static bool stream_is_ident_start(Stream *self) { + return iswalnum(self->next) || self->next == '_' || self->next == '-'; } static void stream_scan_identifier(Stream *stream) { @@ -307,6 +315,10 @@ static void stream_scan_identifier(Stream *stream) { ); } +static uint32_t stream_offset(Stream *self) { + return self->input - self->start; +} + /****************** * CaptureListPool ******************/ @@ -716,7 +728,7 @@ static inline void ts_query__pattern_map_insert( // #define DEBUG_ANALYZE_QUERY -static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index) { +static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { // Identify all of the patterns in the query that have child patterns, both at the // top level and nested within other larger patterns. Record the step index where // each pattern starts. @@ -1165,12 +1177,12 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index // If this pattern cannot match, store the pattern index so that it can be // returned to the caller. if (result && !can_finish_pattern) { - unsigned exists; - array_search_sorted_by( - &self->patterns, 0, - .steps.offset, parent_step_index, - impossible_index, &exists - ); + assert(final_step_indices.size > 0); + uint16_t *impossible_step_index = array_back(&final_step_indices); + uint32_t i, exists; + array_search_sorted_by(&self->step_offsets, 0, .step_index, *impossible_step_index, &i, &exists); + assert(exists); + *error_offset = self->step_offsets.contents[i].byte_offset; result = false; } } @@ -1415,17 +1427,24 @@ static TSQueryError ts_query__parse_pattern( uint32_t depth, bool is_immediate ) { + if (stream->next == 0) return TSQueryErrorSyntax; + if (stream->next == ')' || stream->next == ']') return PARENT_DONE; + const uint32_t starting_step_index = self->steps.size; - if (stream->next == 0) return TSQueryErrorSyntax; - - // Finish the parent S-expression. - if (stream->next == ')' || stream->next == ']') { - return PARENT_DONE; + // Store the byte offset of each step in the query. + if ( + self->step_offsets.size == 0 || + array_back(&self->step_offsets)->step_index != starting_step_index + ) { + array_push(&self->step_offsets, ((StepOffset) { + .step_index = starting_step_index, + .byte_offset = stream_offset(stream), + })); } // An open bracket is the start of an alternation. - else if (stream->next == '[') { + if (stream->next == '[') { stream_advance(stream); stream_skip_whitespace(stream); @@ -1818,6 +1837,7 @@ TSQuery *ts_query_new( .predicate_values = symbol_table_new(), .predicate_steps = array_new(), .patterns = array_new(), + .step_offsets = array_new(), .symbol_map = symbol_map, .wildcard_root_pattern_count = 0, .language = language, @@ -1833,7 +1853,7 @@ TSQuery *ts_query_new( array_push(&self->patterns, ((QueryPattern) { .steps = (Slice) {.offset = start_step_index}, .predicate_steps = (Slice) {.offset = start_predicate_step_index}, - .start_byte = stream.input - source, + .start_byte = stream_offset(&stream), })); *error_type = ts_query__parse_pattern(self, &stream, 0, false); array_push(&self->steps, query_step__new(0, PATTERN_DONE_MARKER, false)); @@ -1846,7 +1866,7 @@ TSQuery *ts_query_new( // and terminate. if (*error_type) { if (*error_type == PARENT_DONE) *error_type = TSQueryErrorSyntax; - *error_offset = stream.input - source; + *error_offset = stream_offset(&stream); ts_query_delete(self); return NULL; } @@ -1882,10 +1902,8 @@ TSQuery *ts_query_new( } if (self->language->version >= TREE_SITTER_LANGUAGE_VERSION_WITH_STATE_COUNT) { - unsigned impossible_pattern_index = 0; - if (!ts_query__analyze_patterns(self, &impossible_pattern_index)) { + if (!ts_query__analyze_patterns(self, error_offset)) { *error_type = TSQueryErrorPattern; - *error_offset = self->patterns.contents[impossible_pattern_index].start_byte; ts_query_delete(self); return NULL; } @@ -1901,6 +1919,7 @@ void ts_query_delete(TSQuery *self) { array_delete(&self->pattern_map); array_delete(&self->predicate_steps); array_delete(&self->patterns); + array_delete(&self->step_offsets); symbol_table_delete(&self->captures); symbol_table_delete(&self->predicate_values); ts_free(self->symbol_map); @@ -1953,24 +1972,21 @@ uint32_t ts_query_start_byte_for_pattern( return self->patterns.contents[pattern_index].start_byte; } -bool ts_query_pattern_is_definite( +bool ts_query_step_is_definite( const TSQuery *self, - uint32_t pattern_index, - TSSymbol symbol, - uint32_t index + uint32_t byte_offset ) { - uint32_t step_index = self->patterns.contents[pattern_index].steps.offset; - QueryStep *step = &self->steps.contents[step_index]; - for (; step->depth != PATTERN_DONE_MARKER; step++) { - bool does_match = symbol ? - step->symbol == symbol : - step->symbol == WILDCARD_SYMBOL || step->symbol == NAMED_WILDCARD_SYMBOL; - if (does_match) { - if (index == 0) return step->is_definite; - index--; - } + uint32_t step_index = UINT32_MAX; + for (unsigned i = 0; i < self->step_offsets.size; i++) { + StepOffset *step_offset = &self->step_offsets.contents[i]; + if (step_offset->byte_offset >= byte_offset) break; + step_index = step_offset->step_index; + } + if (step_index < self->steps.size) { + return self->steps.contents[step_index].is_definite; + } else { + return false; } - return false; } void ts_query_disable_capture(