diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 1703e610..6634c62f 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -7,8 +7,8 @@ use lazy_static::lazy_static; use rand::{prelude::StdRng, SeedableRng}; use std::{env, fmt::Write}; use tree_sitter::{ - Language, Node, Parser, Point, Query, QueryCapture, QueryCursor, QueryError, QueryErrorKind, - QueryMatch, QueryPredicate, QueryPredicateArg, QueryProperty, + CaptureQuantifier, Language, Node, Parser, Point, Query, QueryCapture, QueryCursor, QueryError, + QueryErrorKind, QueryMatch, QueryPredicate, QueryPredicateArg, QueryProperty, }; lazy_static! { @@ -3818,6 +3818,243 @@ fn test_query_is_pattern_guaranteed_at_step() { }); } +#[test] +fn test_capture_quantifiers() { + struct Row { + description: &'static str, + language: Language, + pattern: &'static str, + capture_quantifiers: &'static [(usize, &'static str, CaptureQuantifier)], + } + + let rows = &[ + // Simple quantifiers + Row { + description: "Top level capture", + language: get_language("python"), + pattern: r#" + (module) @mod + "#, + capture_quantifiers: &[(0, "mod", CaptureQuantifier::One)], + }, + Row { + description: "Nested list capture capture", + language: get_language("javascript"), + pattern: r#" + (array (_)* @elems) @array + "#, + capture_quantifiers: &[ + (0, "array", CaptureQuantifier::One), + (0, "elems", CaptureQuantifier::ZeroOrMore), + ], + }, + Row { + description: "Nested non-empty list capture capture", + language: get_language("javascript"), + pattern: r#" + (array (_)+ @elems) @array + "#, + capture_quantifiers: &[ + (0, "array", CaptureQuantifier::One), + (0, "elems", CaptureQuantifier::OneOrMore), + ], + }, + // Nested quantifiers + Row { + description: "capture nested in optional pattern", + language: get_language("javascript"), + pattern: r#" + (array (call_expression (arguments (_) @arg))? @call) @array + "#, + capture_quantifiers: &[ + (0, "array", CaptureQuantifier::One), + (0, "call", CaptureQuantifier::ZeroOrOne), + (0, "arg", CaptureQuantifier::ZeroOrOne), + ], + }, + Row { + description: "optional capture nested in non-empty list pattern", + language: get_language("javascript"), + pattern: r#" + (array (call_expression (arguments (_)? @arg))+ @call) @array + "#, + capture_quantifiers: &[ + (0, "array", CaptureQuantifier::One), + (0, "call", CaptureQuantifier::OneOrMore), + (0, "arg", CaptureQuantifier::ZeroOrMore), + ], + }, + Row { + description: "non-empty list capture nested in optional pattern", + language: get_language("javascript"), + pattern: r#" + (array (call_expression (arguments (_)+ @args))? @call) @array + "#, + capture_quantifiers: &[ + (0, "array", CaptureQuantifier::One), + (0, "call", CaptureQuantifier::ZeroOrOne), + (0, "args", CaptureQuantifier::ZeroOrMore), + ], + }, + // Quantifiers in alternations + Row { + description: "capture is the same in all alternatives", + language: get_language("javascript"), + pattern: r#"[ + (function_declaration name:(identifier) @name) + (call_expression function:(identifier) @name) + ]"#, + capture_quantifiers: &[(0, "name", CaptureQuantifier::One)], + }, + Row { + description: "capture appears in some alternatives", + language: get_language("javascript"), + pattern: r#"[ + (function_declaration name:(identifier) @name) + (function) + ] @fun"#, + capture_quantifiers: &[ + (0, "fun", CaptureQuantifier::One), + (0, "name", CaptureQuantifier::ZeroOrOne), + ], + }, + Row { + description: "capture has different quantifiers in alternatives", + language: get_language("javascript"), + pattern: r#"[ + (call_expression arguments:(arguments (_)+ @args)) + (new_expression arguments:(arguments (_)? @args)) + ] @call"#, + capture_quantifiers: &[ + (0, "call", CaptureQuantifier::One), + (0, "args", CaptureQuantifier::ZeroOrMore), + ], + }, + // Quantifiers in siblings + Row { + description: "siblings have different captures with different quantifiers", + language: get_language("javascript"), + pattern: r#" + (call_expression (arguments (identifier)? @self (_)* @args)) @call + "#, + capture_quantifiers: &[ + (0, "call", CaptureQuantifier::One), + (0, "self", CaptureQuantifier::ZeroOrOne), + (0, "args", CaptureQuantifier::ZeroOrMore), + ], + }, + Row { + description: "siblings have same capture with different quantifiers", + language: get_language("javascript"), + pattern: r#" + (call_expression (arguments (identifier) @args (_)* @args)) @call + "#, + capture_quantifiers: &[ + (0, "call", CaptureQuantifier::One), + (0, "args", CaptureQuantifier::OneOrMore), + ], + }, + // Combined scenarios + Row { + description: "combined nesting, alternatives, and siblings", + language: get_language("javascript"), + pattern: r#" + (array + (call_expression + (arguments [ + (identifier) @self + (_)+ @args + ]) + )+ @call + ) @array + "#, + capture_quantifiers: &[ + (0, "array", CaptureQuantifier::One), + (0, "call", CaptureQuantifier::OneOrMore), + (0, "self", CaptureQuantifier::ZeroOrMore), + (0, "args", CaptureQuantifier::ZeroOrMore), + ], + }, + // Multiple patterns + Row { + description: "multiple patterns", + language: get_language("javascript"), + pattern: r#" + (function_declaration name: (identifier) @x) + (statement_identifier) @y + (property_identifier)+ @z + (array (identifier)* @x) + "#, + capture_quantifiers: &[ + // x + (0, "x", CaptureQuantifier::One), + (1, "x", CaptureQuantifier::Zero), + (2, "x", CaptureQuantifier::Zero), + (3, "x", CaptureQuantifier::ZeroOrMore), + // y + (0, "y", CaptureQuantifier::Zero), + (1, "y", CaptureQuantifier::One), + (2, "y", CaptureQuantifier::Zero), + (3, "y", CaptureQuantifier::Zero), + // z + (0, "z", CaptureQuantifier::Zero), + (1, "z", CaptureQuantifier::Zero), + (2, "z", CaptureQuantifier::OneOrMore), + (3, "z", CaptureQuantifier::Zero), + ], + }, + Row { + description: "multiple alternatives", + language: get_language("javascript"), + pattern: r#" + [ + (array (identifier) @x) + (function_declaration name: (identifier)+ @x) + ] + [ + (array (identifier) @x) + (function_declaration name: (identifier)+ @x) + ] + "#, + capture_quantifiers: &[ + (0, "x", CaptureQuantifier::OneOrMore), + (1, "x", CaptureQuantifier::OneOrMore), + ], + }, + ]; + + allocations::record(|| { + eprintln!(""); + + for row in rows.iter() { + if let Some(filter) = EXAMPLE_FILTER.as_ref() { + if !row.description.contains(filter.as_str()) { + continue; + } + } + eprintln!(" query example: {:?}", row.description); + let query = Query::new(row.language, row.pattern).unwrap(); + for (pattern, capture, expected_quantifier) in row.capture_quantifiers { + let index = query.capture_index_for_name(capture).unwrap(); + let actual_quantifier = query.capture_quantifiers(*pattern)[index as usize]; + assert_eq!( + actual_quantifier, + *expected_quantifier, + "Description: {}, Pattern: {:?}, expected quantifier of @{} to be {:?} instead of {:?}", + row.description, + row.pattern + .split_ascii_whitespace() + .collect::>() + .join(" "), + capture, + *expected_quantifier, + actual_quantifier, + ) + } + } + }); +} + fn assert_query_matches( language: Language, query: &Query, diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index 5bcbac42..a79f432e 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -107,6 +107,12 @@ pub struct TSQueryCapture { pub node: TSNode, pub index: u32, } +pub const TSQuantifier_TSQuantifierZero: TSQuantifier = 0; +pub const TSQuantifier_TSQuantifierZeroOrOne: TSQuantifier = 1; +pub const TSQuantifier_TSQuantifierZeroOrMore: TSQuantifier = 2; +pub const TSQuantifier_TSQuantifierOne: TSQuantifier = 3; +pub const TSQuantifier_TSQuantifierOneOrMore: TSQuantifier = 4; +pub type TSQuantifier = ::std::os::raw::c_uint; #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct TSQueryMatch { @@ -665,6 +671,15 @@ extern "C" { length: *mut u32, ) -> *const ::std::os::raw::c_char; } +extern "C" { + #[doc = " Get the quantifier of the query's captures. Each capture is * associated"] + #[doc = " with a numeric id based on the order that it appeared in the query's source."] + pub fn ts_query_capture_quantifier_for_id( + arg1: *const TSQuery, + pattern_id: u32, + capture_id: u32, + ) -> TSQuantifier; +} extern "C" { pub fn ts_query_string_value_for_id( arg1: *const TSQuery, diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index cf8437b8..e88a411c 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -98,12 +98,36 @@ pub struct TreeCursor<'a>(ffi::TSTreeCursor, PhantomData<&'a ()>); pub struct Query { ptr: NonNull, capture_names: Vec, + capture_quantifiers: Vec>, text_predicates: Vec>, property_settings: Vec>, property_predicates: Vec>, general_predicates: Vec>, } +/// A quantifier for captures +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum CaptureQuantifier { + Zero, + ZeroOrOne, + ZeroOrMore, + One, + OneOrMore, +} + +impl From for CaptureQuantifier { + fn from(value: ffi::TSQuantifier) -> Self { + match value { + ffi::TSQuantifier_TSQuantifierZero => CaptureQuantifier::Zero, + ffi::TSQuantifier_TSQuantifierZeroOrOne => CaptureQuantifier::ZeroOrOne, + ffi::TSQuantifier_TSQuantifierZeroOrMore => CaptureQuantifier::ZeroOrMore, + ffi::TSQuantifier_TSQuantifierOne => CaptureQuantifier::One, + ffi::TSQuantifier_TSQuantifierOneOrMore => CaptureQuantifier::OneOrMore, + _ => panic!("Unrecognized quantifier: {}", value), + } + } +} + /// A stateful object for executing a `Query` on a syntax `Tree`. pub struct QueryCursor { ptr: NonNull, @@ -1306,6 +1330,7 @@ impl Query { let mut result = Query { ptr: unsafe { NonNull::new_unchecked(ptr) }, capture_names: Vec::with_capacity(capture_count as usize), + capture_quantifiers: Vec::with_capacity(pattern_count as usize), text_predicates: Vec::with_capacity(pattern_count), property_predicates: Vec::with_capacity(pattern_count), property_settings: Vec::with_capacity(pattern_count), @@ -1324,6 +1349,18 @@ impl Query { } } + // Build a vector to store capture qunatifiers. + for i in 0..pattern_count { + let mut capture_quantifiers = Vec::with_capacity(capture_count as usize); + for j in 0..capture_count { + unsafe { + let quantifier = ffi::ts_query_capture_quantifier_for_id(ptr, i as u32, j); + capture_quantifiers.push(quantifier.into()); + } + } + result.capture_quantifiers.push(capture_quantifiers); + } + // Build a vector of strings to represent literal values used in predicates. let string_values = (0..string_count) .map(|i| unsafe { @@ -1524,6 +1561,11 @@ impl Query { &self.capture_names } + /// Get the quantifiers of the captures used in the query. + pub fn capture_quantifiers(&self, index: usize) -> &[CaptureQuantifier] { + &self.capture_quantifiers[index] + } + /// Get the index for a given capture name. pub fn capture_index_for_name(&self, name: &str) -> Option { self.capture_names diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 33b8c8f4..7266fba7 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -106,6 +106,14 @@ typedef struct { uint32_t index; } TSQueryCapture; +typedef enum { + TSQuantifierZero = 0, // must match the array initialization value + TSQuantifierZeroOrOne, + TSQuantifierZeroOrMore, + TSQuantifierOne, + TSQuantifierOneOrMore, +} TSQuantifier; + typedef struct { uint32_t id; uint16_t pattern_index; @@ -740,6 +748,17 @@ const char *ts_query_capture_name_for_id( uint32_t id, uint32_t *length ); + +/** + * Get the quantifier of the query's captures. Each capture is * associated + * with a numeric id based on the order that it appeared in the query's source. + */ +TSQuantifier ts_query_capture_quantifier_for_id( + const TSQuery *, + uint32_t pattern_id, + uint32_t capture_id +); + const char *ts_query_string_value_for_id( const TSQuery *, uint32_t id, diff --git a/lib/src/query.c b/lib/src/query.c index aba9a0f3..470bcfd2 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -120,6 +120,11 @@ typedef struct { Array(Slice) slices; } SymbolTable; +/** + * CaptureQuantififers - a data structure holding the quantifiers of pattern captures. + */ +typedef Array(uint8_t) CaptureQuantifiers; + /* * PatternEntry - Information about the starting point for matching a particular * pattern. These entries are stored in a 'pattern map' - a sorted array that @@ -264,6 +269,7 @@ typedef struct { */ struct TSQuery { SymbolTable captures; + Array(CaptureQuantifiers) capture_quantifiers; SymbolTable predicate_values; Array(QueryStep) steps; Array(PatternEntry) pattern_map; @@ -455,6 +461,263 @@ static void capture_list_pool_release(CaptureListPool *self, uint16_t id) { self->free_capture_list_count++; } +/************** + * Quantifiers + **************/ + +static TSQuantifier quantifier_mul( + TSQuantifier left, + TSQuantifier right +) { + switch (left) + { + case TSQuantifierZero: + return TSQuantifierZero; + case TSQuantifierZeroOrOne: + switch (right) { + case TSQuantifierZero: + return TSQuantifierZero; + case TSQuantifierZeroOrOne: + case TSQuantifierOne: + return TSQuantifierZeroOrOne; + case TSQuantifierZeroOrMore: + case TSQuantifierOneOrMore: + return TSQuantifierZeroOrMore; + }; + break; + case TSQuantifierZeroOrMore: + switch (right) { + case TSQuantifierZero: + return TSQuantifierZero; + case TSQuantifierZeroOrOne: + case TSQuantifierZeroOrMore: + case TSQuantifierOne: + case TSQuantifierOneOrMore: + return TSQuantifierZeroOrMore; + }; + break; + case TSQuantifierOne: + return right; + case TSQuantifierOneOrMore: + switch (right) { + case TSQuantifierZero: + return TSQuantifierZero; + case TSQuantifierZeroOrOne: + case TSQuantifierZeroOrMore: + return TSQuantifierZeroOrMore; + case TSQuantifierOne: + case TSQuantifierOneOrMore: + return TSQuantifierOneOrMore; + }; + break; + } + return TSQuantifierZero; // to make compiler happy, but all cases should be covered above! +} + +static TSQuantifier quantifier_join( + TSQuantifier left, + TSQuantifier right +) { + switch (left) + { + case TSQuantifierZero: + switch (right) { + case TSQuantifierZero: + return TSQuantifierZero; + case TSQuantifierZeroOrOne: + case TSQuantifierOne: + return TSQuantifierZeroOrOne; + case TSQuantifierZeroOrMore: + case TSQuantifierOneOrMore: + return TSQuantifierZeroOrMore; + }; + break; + case TSQuantifierZeroOrOne: + switch (right) { + case TSQuantifierZero: + case TSQuantifierZeroOrOne: + case TSQuantifierOne: + return TSQuantifierZeroOrOne; + break; + case TSQuantifierZeroOrMore: + case TSQuantifierOneOrMore: + return TSQuantifierZeroOrMore; + break; + }; + break; + case TSQuantifierZeroOrMore: + return TSQuantifierZeroOrMore; + case TSQuantifierOne: + switch (right) { + case TSQuantifierZero: + case TSQuantifierZeroOrOne: + return TSQuantifierZeroOrOne; + case TSQuantifierZeroOrMore: + return TSQuantifierZeroOrMore; + case TSQuantifierOne: + return TSQuantifierOne; + case TSQuantifierOneOrMore: + return TSQuantifierOneOrMore; + }; + break; + case TSQuantifierOneOrMore: + switch (right) { + case TSQuantifierZero: + case TSQuantifierZeroOrOne: + case TSQuantifierZeroOrMore: + return TSQuantifierZeroOrMore; + case TSQuantifierOne: + case TSQuantifierOneOrMore: + return TSQuantifierOneOrMore; + }; + break; + } + return TSQuantifierZero; // to make compiler happy, but all cases should be covered above! +} + +static TSQuantifier quantifier_add( + TSQuantifier left, + TSQuantifier right +) { + switch (left) + { + case TSQuantifierZero: + return right; + case TSQuantifierZeroOrOne: + switch (right) { + case TSQuantifierZero: + return TSQuantifierZeroOrOne; + case TSQuantifierZeroOrOne: + case TSQuantifierZeroOrMore: + return TSQuantifierZeroOrMore; + case TSQuantifierOne: + case TSQuantifierOneOrMore: + return TSQuantifierOneOrMore; + }; + break; + case TSQuantifierZeroOrMore: + switch (right) { + case TSQuantifierZero: + return TSQuantifierZeroOrMore; + case TSQuantifierZeroOrOne: + case TSQuantifierZeroOrMore: + return TSQuantifierZeroOrMore; + case TSQuantifierOne: + case TSQuantifierOneOrMore: + return TSQuantifierOneOrMore; + }; + break; + case TSQuantifierOne: + switch (right) { + case TSQuantifierZero: + return TSQuantifierOne; + case TSQuantifierZeroOrOne: + case TSQuantifierZeroOrMore: + case TSQuantifierOne: + case TSQuantifierOneOrMore: + return TSQuantifierOneOrMore; + }; + break; + case TSQuantifierOneOrMore: + return TSQuantifierOneOrMore; + } + return TSQuantifierZero; // to make compiler happy, but all cases should be covered above! +} + +// Create new capture quantifiers structure +static CaptureQuantifiers capture_quantifiers_new(void) { + return (CaptureQuantifiers) array_new(); +} + +// Delete capture quantifiers structure +static void capture_quantifiers_delete( + CaptureQuantifiers *self +) { + array_delete(self); +} + +// Clear capture quantifiers structure +static void capture_quantifiers_clear( + CaptureQuantifiers *self +) { + array_clear(self); +} + +// Replace capture quantifiers with the given quantifiers +static void capture_quantifiers_replace( + CaptureQuantifiers *self, + CaptureQuantifiers *quantifiers +) { + array_clear(self); + array_push_all(self, quantifiers); +} + +// Return capture quantifier for the given capture id +static TSQuantifier capture_quantifier_for_id( + const CaptureQuantifiers *self, + uint16_t id +) { + return (self->size <= id) ? TSQuantifierZero : (TSQuantifier) *array_get(self, id); +} + +// Add the given quantifier to the current value for id +static void capture_quantifiers_add_for_id( + CaptureQuantifiers *self, + uint16_t id, + TSQuantifier quantifier +) { + if (self->size <= id) { + array_grow_by(self, id + 1 - self->size); + } + uint8_t *own_quantifier = array_get(self, id); + *own_quantifier = (uint8_t) quantifier_add((TSQuantifier) *own_quantifier, quantifier); +} + +// Point-wise add the given quantifiers to the current values +static void capture_quantifiers_add_all( + CaptureQuantifiers *self, + CaptureQuantifiers *quantifiers +) { + if (self->size < quantifiers->size) { + array_grow_by(self, quantifiers->size - self->size); + } + for (uint16_t id = 0; id < quantifiers->size; id++) { + uint8_t *quantifier = array_get(quantifiers, id); + uint8_t *own_quantifier = array_get(self, id); + *own_quantifier = (uint8_t) quantifier_add((TSQuantifier) *own_quantifier, (TSQuantifier) *quantifier); + } +} + +// Join the given quantifier with the current values +static void capture_quantifiers_mul( + CaptureQuantifiers *self, + TSQuantifier quantifier +) { + for (uint16_t id = 0; id < self->size; id++) { + uint8_t *own_quantifier = array_get(self, id); + *own_quantifier = (uint8_t) quantifier_mul((TSQuantifier) *own_quantifier, quantifier); + } +} + +// Point-wise join the quantifiers from a list of alternatives with the current values +static void capture_quantifiers_join_all( + CaptureQuantifiers *self, + CaptureQuantifiers *quantifiers +) { + if (self->size < quantifiers->size) { + array_grow_by(self, quantifiers->size - self->size); + } + for (uint32_t id = 0; id < quantifiers->size; id++) { + uint8_t *quantifier = array_get(quantifiers, id); + uint8_t *own_quantifier = array_get(self, id); + *own_quantifier = (uint8_t) quantifier_join((TSQuantifier) *own_quantifier, (TSQuantifier) *quantifier); + } + for (uint32_t id = quantifiers->size; id < self->size; id++) { + uint8_t *own_quantifier = array_get(self, id); + *own_quantifier = (uint8_t) quantifier_join((TSQuantifier) *own_quantifier, TSQuantifierZero); + } +} + /************** * SymbolTable **************/ @@ -1779,11 +2042,15 @@ static TSQueryError ts_query__parse_predicate( // Read one S-expression pattern from the stream, and incorporate it into // the query's internal state machine representation. For nested patterns, // this function calls itself recursively. +// +// The caller is repsonsible for passing in a dedicated CaptureQuantifiers. +// These should not be shared between different calls to ts_query__parse_pattern! static TSQueryError ts_query__parse_pattern( TSQuery *self, Stream *stream, uint32_t depth, - bool is_immediate + bool is_immediate, + CaptureQuantifiers *capture_quantifiers ) { if (stream->next == 0) return TSQueryErrorSyntax; if (stream->next == ')' || stream->next == ']') return PARENT_DONE; @@ -1808,13 +2075,15 @@ static TSQueryError ts_query__parse_pattern( // Parse each branch, and add a placeholder step in between the branches. Array(uint32_t) branch_step_indices = array_new(); + CaptureQuantifiers branch_capture_quantifiers = capture_quantifiers_new(); for (;;) { uint32_t start_index = self->steps.size; TSQueryError e = ts_query__parse_pattern( self, stream, depth, - is_immediate + is_immediate, + &branch_capture_quantifiers ); if (e == PARENT_DONE) { @@ -1825,12 +2094,20 @@ static TSQueryError ts_query__parse_pattern( e = TSQueryErrorSyntax; } if (e) { + capture_quantifiers_delete(&branch_capture_quantifiers); array_delete(&branch_step_indices); return e; } + if(start_index == starting_step_index) { + capture_quantifiers_replace(capture_quantifiers, &branch_capture_quantifiers); + } else { + capture_quantifiers_join_all(capture_quantifiers, &branch_capture_quantifiers); + } + array_push(&branch_step_indices, start_index); array_push(&self->steps, query_step__new(0, depth, false)); + capture_quantifiers_clear(&branch_capture_quantifiers); } (void)array_pop(&self->steps); @@ -1846,6 +2123,7 @@ static TSQueryError ts_query__parse_pattern( end_step->is_dead_end = true; } + capture_quantifiers_delete(&branch_capture_quantifiers); array_delete(&branch_step_indices); } @@ -1860,6 +2138,7 @@ static TSQueryError ts_query__parse_pattern( // If this parenthesis is followed by a node, then it represents a grouped sequence. if (stream->next == '(' || stream->next == '"' || stream->next == '[') { bool child_is_immediate = false; + CaptureQuantifiers child_capture_quantifiers = capture_quantifiers_new(); for (;;) { if (stream->next == '.') { child_is_immediate = true; @@ -1870,7 +2149,8 @@ static TSQueryError ts_query__parse_pattern( self, stream, depth, - child_is_immediate + child_is_immediate, + &child_capture_quantifiers ); if (e == PARENT_DONE) { if (stream->next == ')') { @@ -1879,10 +2159,17 @@ static TSQueryError ts_query__parse_pattern( } e = TSQueryErrorSyntax; } - if (e) return e; + if (e) { + capture_quantifiers_delete(&child_capture_quantifiers); + return e; + } + + capture_quantifiers_add_all(capture_quantifiers, &child_capture_quantifiers); child_is_immediate = false; + capture_quantifiers_clear(&child_capture_quantifiers); } + capture_quantifiers_delete(&child_capture_quantifiers); } // A dot/pound character indicates the start of a predicate. @@ -1971,12 +2258,16 @@ static TSQueryError ts_query__parse_pattern( uint16_t last_child_step_index = 0; uint16_t negated_field_count = 0; TSFieldId negated_field_ids[MAX_NEGATED_FIELD_COUNT]; + CaptureQuantifiers child_capture_quantifiers = capture_quantifiers_new(); for (;;) { // Parse a negated field assertion if (stream->next == '!') { stream_advance(stream); stream_skip_whitespace(stream); - if (!stream_is_ident_start(stream)) return TSQueryErrorSyntax; + if (!stream_is_ident_start(stream)) { + capture_quantifiers_delete(&child_capture_quantifiers); + return TSQueryErrorSyntax; + } const char *field_name = stream->input; stream_scan_identifier(stream); uint32_t length = stream->input - field_name; @@ -1989,6 +2280,7 @@ static TSQueryError ts_query__parse_pattern( ); if (!field_id) { stream->input = field_name; + capture_quantifiers_delete(&child_capture_quantifiers); return TSQueryErrorField; } @@ -2013,12 +2305,16 @@ static TSQueryError ts_query__parse_pattern( self, stream, depth + 1, - child_is_immediate + child_is_immediate, + &child_capture_quantifiers ); if (e == PARENT_DONE) { if (stream->next == ')') { if (child_is_immediate) { - if (last_child_step_index == 0) return TSQueryErrorSyntax; + if (last_child_step_index == 0) { + capture_quantifiers_delete(&child_capture_quantifiers); + return TSQueryErrorSyntax; + } self->steps.contents[last_child_step_index].is_last_child = true; } @@ -2036,11 +2332,18 @@ static TSQueryError ts_query__parse_pattern( } e = TSQueryErrorSyntax; } - if (e) return e; + if (e) { + capture_quantifiers_delete(&child_capture_quantifiers); + return e; + } + + capture_quantifiers_add_all(capture_quantifiers, &child_capture_quantifiers); last_child_step_index = step_index; child_is_immediate = false; + capture_quantifiers_clear(&child_capture_quantifiers); } + capture_quantifiers_delete(&child_capture_quantifiers); } } @@ -2089,14 +2392,19 @@ static TSQueryError ts_query__parse_pattern( stream_skip_whitespace(stream); // Parse the pattern + CaptureQuantifiers field_capture_quantifiers = capture_quantifiers_new(); TSQueryError e = ts_query__parse_pattern( self, stream, depth, - is_immediate + is_immediate, + &field_capture_quantifiers ); - if (e == PARENT_DONE) return TSQueryErrorSyntax; - if (e) return e; + if (e) { + capture_quantifiers_delete(&field_capture_quantifiers); + if (e == PARENT_DONE) e = TSQueryErrorSyntax; + return e; + } // Add the field name to the first step of the pattern TSFieldId field_id = ts_language_field_id_for_name( @@ -2124,6 +2432,9 @@ static TSQueryError ts_query__parse_pattern( break; } } + + capture_quantifiers_add_all(capture_quantifiers, &field_capture_quantifiers); + capture_quantifiers_delete(&field_capture_quantifiers); } else { @@ -2133,9 +2444,12 @@ static TSQueryError ts_query__parse_pattern( stream_skip_whitespace(stream); // Parse suffixes modifiers for this pattern + TSQuantifier quantifier = TSQuantifierOne; for (;;) { // Parse the one-or-more operator. if (stream->next == '+') { + quantifier = quantifier_join(TSQuantifierOneOrMore, quantifier); + stream_advance(stream); stream_skip_whitespace(stream); @@ -2148,6 +2462,8 @@ static TSQueryError ts_query__parse_pattern( // Parse the zero-or-more repetition operator. else if (stream->next == '*') { + quantifier = quantifier_join(TSQuantifierZeroOrMore, quantifier); + stream_advance(stream); stream_skip_whitespace(stream); @@ -2166,6 +2482,8 @@ static TSQueryError ts_query__parse_pattern( // Parse the optional operator. else if (stream->next == '?') { + quantifier = quantifier_join(TSQuantifierZeroOrOne, quantifier); + stream_advance(stream); stream_skip_whitespace(stream); @@ -2192,6 +2510,9 @@ static TSQueryError ts_query__parse_pattern( length ); + // Add the capture quantifier + capture_quantifiers_add_for_id(capture_quantifiers, capture_id, TSQuantifierOne); + uint32_t step_index = starting_step_index; for (;;) { QueryStep *step = &self->steps.contents[step_index]; @@ -2215,6 +2536,8 @@ static TSQueryError ts_query__parse_pattern( } } + capture_quantifiers_mul(capture_quantifiers, quantifier); + return 0; } @@ -2239,6 +2562,7 @@ TSQuery *ts_query_new( .steps = array_new(), .pattern_map = array_new(), .captures = symbol_table_new(), + .capture_quantifiers = array_new(), .predicate_values = symbol_table_new(), .predicate_steps = array_new(), .patterns = array_new(), @@ -2263,7 +2587,8 @@ TSQuery *ts_query_new( .predicate_steps = (Slice) {.offset = start_predicate_step_index}, .start_byte = stream_offset(&stream), })); - *error_type = ts_query__parse_pattern(self, &stream, 0, false); + CaptureQuantifiers capture_quantifiers = capture_quantifiers_new(); + *error_type = ts_query__parse_pattern(self, &stream, 0, false, &capture_quantifiers); array_push(&self->steps, query_step__new(0, PATTERN_DONE_MARKER, false)); QueryPattern *pattern = array_back(&self->patterns); @@ -2275,10 +2600,14 @@ TSQuery *ts_query_new( if (*error_type) { if (*error_type == PARENT_DONE) *error_type = TSQueryErrorSyntax; *error_offset = stream_offset(&stream); + capture_quantifiers_delete(&capture_quantifiers); ts_query_delete(self); return NULL; } + // Maintain a list of capture quantifiers for each pattern + array_push(&self->capture_quantifiers, capture_quantifiers); + // Maintain a map that can look up patterns for a given root symbol. uint16_t wildcard_root_alternative_index = NONE; for (;;) { @@ -2354,6 +2683,11 @@ void ts_query_delete(TSQuery *self) { array_delete(&self->negated_fields); symbol_table_delete(&self->captures); symbol_table_delete(&self->predicate_values); + for (uint32_t index = 0; index < self->capture_quantifiers.size; index++) { + CaptureQuantifiers *capture_quantifiers = array_get(&self->capture_quantifiers, index); + capture_quantifiers_delete(capture_quantifiers); + } + array_delete(&self->capture_quantifiers); ts_free(self); } } @@ -2378,6 +2712,15 @@ const char *ts_query_capture_name_for_id( return symbol_table_name_for_id(&self->captures, index, length); } +TSQuantifier ts_query_capture_quantifier_for_id( + const TSQuery *self, + uint32_t pattern_index, + uint32_t capture_index +) { + CaptureQuantifiers *capture_quantifiers = array_get(&self->capture_quantifiers, pattern_index); + return capture_quantifier_for_id(capture_quantifiers, capture_index); +} + const char *ts_query_string_value_for_id( const TSQuery *self, uint32_t index,