diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index 9d398f58..5da90d92 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -107,6 +107,11 @@ pub struct TSQueryCapture { pub node: TSNode, pub index: u32, } +pub const TSQuantifier_One: TSQuantifier = 0; +pub const TSQuantifier_OneOrMore: TSQuantifier = 1; +pub const TSQuantifier_ZeroOrOne: TSQuantifier = 2; +pub const TSQuantifier_ZeroOrMore: TSQuantifier = 3; +pub type TSQuantifier = u32; #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct TSQueryMatch { @@ -666,13 +671,10 @@ extern "C" { ) -> *const ::std::os::raw::c_char; } extern "C" { - #[doc = " Get the suffix of one of the query's captures, or one of the query's"] - #[doc = " string literals. Each capture and string is associated with a numeric"] - #[doc = " id based on the order that it appeared in the query's source."] - pub fn ts_query_capture_suffix_for_id( - arg1: *const TSQuery, - id: u32, - ) -> ::std::os::raw::c_char; + #[doc = " Get the quantifier of one of the query's captures. Each capture is"] + #[doc = " associated with a numeric id based on the order that it appeared in the"]
+ #[doc = " query's source."] + pub fn ts_query_capture_quantifier_for_id(arg1: *const TSQuery, id: u32) -> TSQuantifier; } extern "C" { pub fn ts_query_string_value_for_id( diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 429cd47c..952d8864 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -98,30 +98,30 @@ pub struct TreeCursor<'a>(ffi::TSTreeCursor, PhantomData<&'a ()>); pub struct Query { ptr: NonNull, capture_names: Vec, - capture_suffixes: Vec, + capture_quantifiers: Vec, text_predicates: Vec>, property_settings: Vec>, property_predicates: Vec>, general_predicates: Vec>, } -/// A suffix indicating the multiplicity of the capture value +/// A quantifier for captures #[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum QueryCaptureSuffix { +pub enum Quantifier { One, OneOrMore, - ZeroOrMore, ZeroOrOne, + ZeroOrMore, } -impl From for QueryCaptureSuffix { - fn from(value: u8) -> QueryCaptureSuffix { +impl From for Quantifier { + fn from(value: ffi::TSQuantifier) -> Self { match value { - b'\0' => QueryCaptureSuffix::One, - b'+' => QueryCaptureSuffix::OneOrMore, - b'*' => QueryCaptureSuffix::ZeroOrMore, - b'?' 
=> QueryCaptureSuffix::ZeroOrOne, - _ => panic!("Unrecognized suffix: {}", value as char), + ffi::TSQuantifier_One => Quantifier::One, + ffi::TSQuantifier_OneOrMore => Quantifier::OneOrMore, + ffi::TSQuantifier_ZeroOrOne => Quantifier::ZeroOrOne, + ffi::TSQuantifier_ZeroOrMore => Quantifier::ZeroOrMore, + _ => panic!("Unrecognized quantifier: {}", value), } } } @@ -1328,7 +1328,7 @@ impl Query { let mut result = Query { ptr: unsafe { NonNull::new_unchecked(ptr) }, capture_names: Vec::with_capacity(capture_count as usize), - capture_suffixes: Vec::with_capacity(capture_count as usize), + capture_quantifiers: Vec::with_capacity(capture_count as usize), text_predicates: Vec::with_capacity(pattern_count), property_predicates: Vec::with_capacity(pattern_count), property_settings: Vec::with_capacity(pattern_count), @@ -1344,8 +1344,8 @@ impl Query { let name = slice::from_raw_parts(name, length as usize); let name = str::from_utf8_unchecked(name); result.capture_names.push(name.to_string()); - let suffix = ffi::ts_query_capture_suffix_for_id(ptr, i) as u8; - result.capture_suffixes.push(suffix.into()); + let quantifier = ffi::ts_query_capture_quantifier_for_id(ptr, i); + result.capture_quantifiers.push(quantifier.into()); } } @@ -1549,9 +1549,9 @@ impl Query { &self.capture_names } - /// Get the suffixes of the captures used in the query. - pub fn capture_suffixes(&self) -> &[QueryCaptureSuffix] { - &self.capture_suffixes + /// Get the quantifiers of the captures used in the query. + pub fn capture_quantifiers(&self) -> &[Quantifier] { + &self.capture_quantifiers } /// Get the index for a given capture name. 
diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 68b67668..b4f77f46 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -106,6 +106,13 @@ typedef struct { uint32_t index; } TSQueryCapture; +typedef enum { + One, + OneOrMore, + ZeroOrOne, + ZeroOrMore, +} TSQuantifier; + typedef struct { uint32_t id; uint16_t pattern_index; @@ -740,7 +747,13 @@ const char *ts_query_capture_name_for_id( uint32_t id, uint32_t *length ); -char ts_query_capture_suffix_for_id( + +/** + * Get the quantifier of one of the query's captures. Each capture is + * associated with a numeric id based on the order that it appeared in the + * query's source. + */ +TSQuantifier ts_query_capture_quantifier_for_id( const TSQuery *, uint32_t id ); diff --git a/lib/src/query.c b/lib/src/query.c index 734d9bcf..83c6b297 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -39,9 +39,6 @@ typedef struct { * was not specified. * - `capture_ids` - An array of integers representing the names of captures * associated with this node in the pattern, terminated by a `NONE` value. - * - `capture_suffixes` - An array of capture suffixes ('\0\, '+', '*', or '?') - * corresponding to the elements in `capture_ids`, terminated by a `NONE` - * value. * - `depth` - The depth where this node occurs in the pattern. The root node * of the pattern has depth zero. * - `negated_field_list_id` - An id representing a set of fields that must @@ -75,7 +72,7 @@ typedef struct { * Steps also store some derived state that summarizes how they relate to other * steps within the same pattern. This is used to optimize the matching process: * - `contains_captures` - Indicates that this step or one of its child steps - * has non-empty `capture_ids` and `capture_suffixes` lists. + * has a non-empty `capture_ids` list. 
* - `parent_pattern_guaranteed` - Indicates that if this step is reached, then * it and all of its subsequent sibling steps within the same parent pattern * are guaranteed to match. @@ -90,7 +87,6 @@ typedef struct { TSSymbol supertype_symbol; TSFieldId field; uint16_t capture_ids[MAX_STEP_CAPTURE_COUNT]; - uint16_t capture_suffixes[MAX_STEP_CAPTURE_COUNT]; uint16_t depth; uint16_t alternative_index; uint16_t negated_field_list_id; @@ -122,7 +118,7 @@ typedef struct { typedef struct { Array(char) characters; Array(Slice) slices; - Array(char) suffixes; + Array(TSQuantifier) quantifiers; } SymbolTable; /* @@ -458,6 +454,43 @@ static void capture_list_pool_release(CaptureListPool *self, uint16_t id) { self->free_capture_list_count++; } +/************** + * Quantifiers + **************/ + +static TSQuantifier quantifier_join( + TSQuantifier left, + TSQuantifier right +) { + switch (left) + { + case One: + return right; + case OneOrMore: + switch (right) { + case One: + case OneOrMore: + return OneOrMore; + case ZeroOrOne: + case ZeroOrMore: + return ZeroOrMore; + }; + break; + case ZeroOrOne: + switch (right) { + case One: + case ZeroOrOne: + return ZeroOrOne; + case OneOrMore: + case ZeroOrMore: + return ZeroOrMore; + }; + break; + case ZeroOrMore: + return ZeroOrMore; + } +} + /************** * SymbolTable **************/ @@ -466,14 +499,14 @@ static SymbolTable symbol_table_new(void) { return (SymbolTable) { .characters = array_new(), .slices = array_new(), - .suffixes = array_new(), + .quantifiers = array_new(), }; } static void symbol_table_delete(SymbolTable *self) { array_delete(&self->characters); array_delete(&self->slices); - array_delete(&self->suffixes); + array_delete(&self->quantifiers); } static int symbol_table_id_for_name( @@ -501,18 +534,17 @@ static const char *symbol_table_name_for_id( return &self->characters.contents[slice.offset]; } -static char symbol_table_suffix_for_id( +static TSQuantifier symbol_table_quantifier_for_id( const SymbolTable 
*self, uint16_t id ) { - return self->suffixes.contents[id]; + return self->quantifiers.contents[id]; } static uint16_t symbol_table_insert_name( SymbolTable *self, const char *name, - uint32_t length, - char suffix + uint32_t length ) { int id = symbol_table_id_for_name(self, name, length); if (id >= 0) return (uint16_t)id; @@ -524,10 +556,22 @@ static uint16_t symbol_table_insert_name( memcpy(&self->characters.contents[slice.offset], name, length); self->characters.contents[self->characters.size - 1] = 0; array_push(&self->slices, slice); - array_push(&self->suffixes, suffix); + array_push(&self->quantifiers, One); return self->slices.size - 1; } +static void symbol_table_quantifiers_join( + SymbolTable *self, + TSQuantifier quantifier, + uint32_t start_index, + uint32_t end_index +) { + for (uint32_t index = start_index; index < end_index; index++) { + TSQuantifier *joined_quantifier = &self->quantifiers.contents[index]; + *joined_quantifier = quantifier_join(quantifier, *joined_quantifier); + } +} + /************ * QueryStep ************/ @@ -542,7 +586,6 @@ static QueryStep query_step__new( .depth = depth, .field = 0, .capture_ids = {NONE, NONE, NONE}, - .capture_suffixes = {NONE, NONE, NONE}, .alternative_index = NONE, .negated_field_list_id = 0, .contains_captures = false, @@ -556,11 +599,10 @@ static QueryStep query_step__new( }; } -static void query_step__add_capture(QueryStep *self, uint16_t capture_id, char capture_suffix) { +static void query_step__add_capture(QueryStep *self, uint16_t capture_id) { for (unsigned i = 0; i < MAX_STEP_CAPTURE_COUNT; i++) { if (self->capture_ids[i] == NONE) { self->capture_ids[i] = capture_id; - self->capture_suffixes[i] = capture_suffix; break; } } @@ -570,13 +612,10 @@ static void query_step__remove_capture(QueryStep *self, uint16_t capture_id) { for (unsigned i = 0; i < MAX_STEP_CAPTURE_COUNT; i++) { if (self->capture_ids[i] == capture_id) { self->capture_ids[i] = NONE; - self->capture_suffixes[i] = NONE; while (i + 1 < 
MAX_STEP_CAPTURE_COUNT) { if (self->capture_ids[i + 1] == NONE) break; self->capture_ids[i] = self->capture_ids[i + 1]; - self->capture_suffixes[i] = self->capture_suffixes[i + 1]; self->capture_ids[i + 1] = NONE; - self->capture_suffixes[i + 1] = NONE; i++; } break; @@ -1611,8 +1650,7 @@ static TSQueryError ts_query__parse_predicate( uint16_t id = symbol_table_insert_name( &self->predicate_values, predicate_name, - length, - '\0' + length ); array_push(&self->predicate_steps, ((TSQueryPredicateStep) { .type = TSQueryPredicateStepTypeString, @@ -1665,8 +1703,7 @@ static TSQueryError ts_query__parse_predicate( uint16_t id = symbol_table_insert_name( &self->predicate_values, self->string_buffer.contents, - self->string_buffer.size, - '\0' + self->string_buffer.size ); array_push(&self->predicate_steps, ((TSQueryPredicateStep) { .type = TSQueryPredicateStepTypeString, @@ -1682,8 +1719,7 @@ static TSQueryError ts_query__parse_predicate( uint16_t id = symbol_table_insert_name( &self->predicate_values, symbol_start, - length, - '\0' + length ); array_push(&self->predicate_steps, ((TSQueryPredicateStep) { .type = TSQueryPredicateStepTypeString, @@ -1714,6 +1750,7 @@ static TSQueryError ts_query__parse_pattern( if (stream->next == ')' || stream->next == ']') return PARENT_DONE; const uint32_t starting_step_index = self->steps.size; + const uint32_t starting_quantifier_index = self->captures.quantifiers.size; // Store the byte offset of each step in the query. 
if ( @@ -1771,6 +1808,16 @@ static TSQueryError ts_query__parse_pattern( end_step->is_dead_end = true; } + if (branch_step_indices.size > 1) { + const uint32_t ending_quantifier_index = self->captures.quantifiers.size; + symbol_table_quantifiers_join( + &self->captures, + ZeroOrOne, + starting_quantifier_index, + ending_quantifier_index + ); + } + array_delete(&branch_step_indices); } @@ -2058,11 +2105,11 @@ static TSQueryError ts_query__parse_pattern( stream_skip_whitespace(stream); // Parse suffixes modifiers for this pattern - char suffix = '\0'; + TSQuantifier quantifier = One; for (;;) { // Parse the one-or-more operator. if (stream->next == '+') { - suffix = '+'; + quantifier = quantifier_join(OneOrMore, quantifier); stream_advance(stream); stream_skip_whitespace(stream); @@ -2076,7 +2123,7 @@ static TSQueryError ts_query__parse_pattern( // Parse the zero-or-more repetition operator. else if (stream->next == '*') { - suffix = '*'; + quantifier = quantifier_join(ZeroOrMore, quantifier); stream_advance(stream); stream_skip_whitespace(stream); @@ -2096,7 +2143,7 @@ static TSQueryError ts_query__parse_pattern( // Parse the optional operator. 
else if (stream->next == '?') { - suffix = '?'; + quantifier = quantifier_join(ZeroOrOne, quantifier); stream_advance(stream); stream_skip_whitespace(stream); @@ -2121,14 +2168,13 @@ static TSQueryError ts_query__parse_pattern( uint16_t capture_id = symbol_table_insert_name( &self->captures, capture_name, - length, - suffix + length ); uint32_t step_index = starting_step_index; for (;;) { QueryStep *step = &self->steps.contents[step_index]; - query_step__add_capture(step, capture_id, suffix); + query_step__add_capture(step, capture_id); if ( step->alternative_index != NONE && step->alternative_index > step_index && @@ -2148,6 +2194,17 @@ static TSQueryError ts_query__parse_pattern( } } + // Patch capture quantifiers + if (quantifier != One) { + const uint32_t ending_quantifier_index = self->captures.quantifiers.size; + symbol_table_quantifiers_join( + &self->captures, + quantifier, + starting_quantifier_index, + ending_quantifier_index + ); + } + return 0; } @@ -2311,11 +2368,11 @@ const char *ts_query_capture_name_for_id( return symbol_table_name_for_id(&self->captures, index, length); } -char ts_query_capture_suffix_for_id( +TSQuantifier ts_query_capture_quantifier_for_id( const TSQuery *self, uint32_t index ) { - return symbol_table_suffix_for_id(&self->captures, index); + return symbol_table_quantifier_for_id(&self->captures, index); } const char *ts_query_string_value_for_id(