diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 9206cc05..70808a30 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -596,6 +596,41 @@ fn test_query_captures_with_text_conditions() { }); } +#[test] +fn test_query_pattern_after_source_byte() { + let language = get_language("javascript"); + + let patterns_1 = r#" + "+" @operator + "-" @operator + "*" @operator + "=" @operator + "=>" @operator + "#.trim_start(); + + let patterns_2 = " + (identifier) @a + (string) @b + ".trim_start(); + + let patterns_3 = " + ((identifier) @b (match? @b i)) + (function_declaration name: (identifier) @c) + (method_definition name: (identifier) @d) + ".trim_start(); + + let mut source = String::new(); + source += patterns_1; + source += patterns_2; + source += patterns_3; + + let query = Query::new(language, &source).unwrap(); + + assert_eq!(query.start_byte_for_pattern(0), 0); + assert_eq!(query.start_byte_for_pattern(5), patterns_1.len()); + assert_eq!(query.start_byte_for_pattern(7), patterns_1.len() + patterns_2.len()); +} + #[test] fn test_query_capture_names() { allocations::record(|| { diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index 1be6472d..210a6f57 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -601,25 +601,43 @@ extern "C" { pub fn ts_query_delete(arg1: *mut TSQuery); } extern "C" { - #[doc = " Get the number of patterns in the query."] + #[doc = " Get the number of patterns, captures, or string literals in the query."] pub fn ts_query_pattern_count(arg1: *const TSQuery) -> u32; } extern "C" { - #[doc = " Get the predicates for the given pattern in the query."] + pub fn ts_query_capture_count(arg1: *const TSQuery) -> u32; +} +extern "C" { + pub fn ts_query_string_count(arg1: *const TSQuery) -> u32; +} +extern "C" { + #[doc = " Get the byte offset where the given pattern starts in the query\'s source."] + #[doc = ""] + #[doc = " This can be useful when combining queries by concatenating their source"] + #[doc = " code strings."] + pub fn ts_query_start_byte_for_pattern(arg1: *const TSQuery, arg2: u32) -> u32; +} +extern "C" { + #[doc = " Get all of the predicates for the given pattern in the query."] + #[doc = ""] + #[doc = " The predicates are represented as a single array of steps. There are three"] + #[doc = " types of steps in this array, which correspond to the three legal values for"] + #[doc = " the `type` field:"] + #[doc = " - `TSQueryPredicateStepTypeCapture` - Steps with this type represent names"] + #[doc = " of captures. Their `value_id` can be used with the"] + #[doc = " `ts_query_capture_name_for_id` function to obtain the name of the capture."] + #[doc = " - `TSQueryPredicateStepTypeString` - Steps with this type represent literal"] + #[doc = " strings. Their `value_id` can be used with the"] + #[doc = " `ts_query_string_value_for_id` function to obtain their string value."] + #[doc = " - `TSQueryPredicateStepTypeDone` - Steps with this type are *sentinels*"] + #[doc = " that represent the end of an individual predicate. If a pattern has two"] + #[doc = " predicates, then there will be two steps with this `type` in the array."] pub fn ts_query_predicates_for_pattern( self_: *const TSQuery, pattern_index: u32, length: *mut u32, ) -> *const TSQueryPredicateStep; } -extern "C" { - #[doc = " Get the number of distinct capture names in the query, or the number of"] - #[doc = " distinct string literals in the query."] - pub fn ts_query_capture_count(arg1: *const TSQuery) -> u32; -} -extern "C" { - pub fn ts_query_string_count(arg1: *const TSQuery) -> u32; -} extern "C" { #[doc = " Get the name and length of one of the query\'s captures, or one of the"] #[doc = " query\'s string literals. Each capture and string is associated with a"] @@ -637,22 +655,6 @@ extern "C" { length: *mut u32, ) -> *const ::std::os::raw::c_char; } -extern "C" { - #[doc = " Get the numeric id of the capture with the given name, or string with the"] - #[doc = " given value."] - pub fn ts_query_capture_id_for_name( - self_: *const TSQuery, - name: *const ::std::os::raw::c_char, - length: u32, - ) -> ::std::os::raw::c_int; -} -extern "C" { - pub fn ts_query_string_id_for_value( - self_: *const TSQuery, - value: *const ::std::os::raw::c_char, - length: u32, - ) -> ::std::os::raw::c_int; -} extern "C" { #[doc = " Create a new cursor for executing a given query."] #[doc = ""] diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 2e0cfa0e..ee457b5f 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -1124,6 +1124,17 @@ impl Query { Ok(result) } + pub fn start_byte_for_pattern(&self, pattern_index: usize) -> usize { + if pattern_index >= self.predicates.len() { + panic!( + "Pattern index is {} but the pattern count is {}", + pattern_index, + self.predicates.len(), + ); + } + unsafe { ffi::ts_query_start_byte_for_pattern(self.ptr, pattern_index as u32) as usize } + } + pub fn capture_names(&self) -> &[String] { &self.capture_names } diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 70215e36..f04370e4 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -670,6 +670,14 @@ uint32_t ts_query_pattern_count(const TSQuery *); uint32_t ts_query_capture_count(const TSQuery *); uint32_t ts_query_string_count(const TSQuery *); +/** + * Get the byte offset where the given pattern starts in the query's source. + * + * This can be useful when combining queries by concatenating their source + * code strings. + */ +uint32_t ts_query_start_byte_for_pattern(const TSQuery *, uint32_t); + /** * Get all of the predicates for the given pattern in the query. * diff --git a/lib/src/query.c b/lib/src/query.c index 6be60956..975e3cef 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -102,12 +102,13 @@ typedef struct { * query is stored in a `TSQueryCursor`. */ struct TSQuery { - Array(QueryStep) steps; SymbolTable captures; SymbolTable predicate_values; + Array(QueryStep) steps; Array(PatternEntry) pattern_map; Array(TSQueryPredicateStep) predicate_steps; Array(Slice) predicates_by_pattern; + Array(uint32_t) start_bytes_by_pattern; const TSLanguage *language; uint16_t max_capture_count; uint16_t wildcard_root_pattern_count; @@ -743,6 +744,7 @@ TSQuery *ts_query_new( for (;;) { start_step_index = self->steps.size; uint32_t capture_count = 0; + array_push(&self->start_bytes_by_pattern, stream.input - source); array_push(&self->predicates_by_pattern, ((Slice) { .offset = self->predicate_steps.size, .length = 0, @@ -787,6 +789,7 @@ void ts_query_delete(TSQuery *self) { array_delete(&self->pattern_map); array_delete(&self->predicate_steps); array_delete(&self->predicates_by_pattern); + array_delete(&self->start_bytes_by_pattern); symbol_table_delete(&self->captures); symbol_table_delete(&self->predicate_values); ts_free(self); @@ -831,6 +834,13 @@ const TSQueryPredicateStep *ts_query_predicates_for_pattern( return &self->predicate_steps.contents[slice.offset]; } +uint32_t ts_query_start_byte_for_pattern( + const TSQuery *self, + uint32_t pattern_index +) { + return self->start_bytes_by_pattern.contents[pattern_index]; +} + /*************** * QueryCursor ***************/