diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs
index 8e2c1313..2245f4f9 100644
--- a/cli/src/tests/query_test.rs
+++ b/cli/src/tests/query_test.rs
@@ -3036,6 +3036,80 @@ fn test_query_text_callback_returns_chunks() {
     });
 }
 
+#[test]
+fn test_query_captures_advance_to_byte() {
+    allocations::record(|| {
+        let source = "[one, two, [three, four, five, six, seven, eight, nine, ten], eleven, twelve, thirteen]";
+        let language = get_language("javascript");
+        let query = Query::new(
+            language,
+            r#"
+            (identifier) @id
+            (array
+                "[" @lbracket
+                "]" @rbracket)
+            "#,
+        )
+        .unwrap();
+
+        let mut parser = Parser::new();
+        parser.set_language(language).unwrap();
+        let tree = parser.parse(&source, None).unwrap();
+
+        // Limit the cursor to the span from just inside `two` up to `, twelve`.
+        let range_start = source.find("two").unwrap() + 1;
+        let range_end = source.find(", twelve").unwrap();
+        let mut cursor = QueryCursor::new();
+        cursor.set_byte_range(range_start, range_end);
+        let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes());
+
+        // Pull only the first four captures, leaving the iterator usable.
+        let first_four: Vec<_> = captures
+            .by_ref()
+            .take(4)
+            .map(|(mat, capture_ix)| {
+                let capture = mat.captures[capture_ix as usize];
+                (
+                    query.capture_names()[capture.index as usize].as_str(),
+                    &source[capture.node.byte_range()],
+                )
+            })
+            .collect();
+        assert_eq!(
+            first_four,
+            vec![
+                ("id", "two"),
+                ("lbracket", "["),
+                ("id", "three"),
+                ("id", "four")
+            ]
+        );
+
+        // Jump ahead to just inside `ten`, then drain whatever is left in range.
+        captures.advance_to_byte(source.find("ten").unwrap() + 1);
+        let remaining: Vec<_> = captures
+            .map(|(mat, capture_ix)| {
+                let capture = mat.captures[capture_ix as usize];
+                (
+                    query.capture_names()[capture.index as usize].as_str(),
+                    &source[capture.node.byte_range()],
+                )
+            })
+            .collect();
+        assert_eq!(
+            remaining,
+            vec![("id", "ten"), ("rbracket", "]"), ("id", "eleven"),]
+        );
+
+        // A fresh run advanced to the end of the source must yield nothing,
+        // however many times it is polled.
+        let mut exhausted = cursor.captures(&query, tree.root_node(), source.as_bytes());
+        exhausted.advance_to_byte(source.len());
+        assert!(exhausted.next().is_none());
+        assert!(exhausted.next().is_none());
+    });
+}
+
 #[test]
 fn test_query_start_byte_for_pattern() {
     let language = get_language("javascript");
diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs
index 50da12fc..a729c12c 100644
--- a/lib/binding_rust/bindings.rs
+++ b/lib/binding_rust/bindings.rs
@@ -1,6 +1,7 @@
-/* automatically generated by rust-bindgen */
+/* automatically generated by rust-bindgen 0.58.1 */
 
 pub type __darwin_size_t = ::std::os::raw::c_ulong;
+pub type size_t = usize;
 pub type FILE = [u64; 19usize];
 pub type TSSymbol = u16;
 pub type TSFieldId = u16;
@@ -31,11 +32,11 @@ pub struct TSQueryCursor {
 }
 pub const TSInputEncoding_TSInputEncodingUTF8: TSInputEncoding = 0;
 pub const TSInputEncoding_TSInputEncodingUTF16: TSInputEncoding = 1;
-pub type TSInputEncoding = u32;
+pub type TSInputEncoding = ::std::os::raw::c_uint;
 pub const TSSymbolType_TSSymbolTypeRegular: TSSymbolType = 0;
 pub const TSSymbolType_TSSymbolTypeAnonymous: TSSymbolType = 1;
 pub const TSSymbolType_TSSymbolTypeAuxiliary: TSSymbolType = 2;
-pub type TSSymbolType = u32;
+pub type TSSymbolType = ::std::os::raw::c_uint;
 #[repr(C)]
 #[derive(Debug, Copy, Clone)]
 pub struct TSPoint {
@@ -66,7 +67,7 @@ pub struct TSInput {
 }
 pub const TSLogType_TSLogTypeParse: TSLogType = 0;
 pub const TSLogType_TSLogTypeLex: TSLogType = 1;
-pub type TSLogType = u32;
+pub type TSLogType = ::std::os::raw::c_uint;
 #[repr(C)]
 #[derive(Debug, Copy, Clone)]
 pub struct TSLogger {
@@ -120,7 +121,7 @@ pub struct TSQueryMatch {
 pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeDone: TSQueryPredicateStepType = 0;
 pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeCapture: TSQueryPredicateStepType = 1;
 pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeString: TSQueryPredicateStepType = 2;
-pub type 
TSQueryPredicateStepType = u32; +pub type TSQueryPredicateStepType = ::std::os::raw::c_uint; #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct TSQueryPredicateStep { @@ -133,7 +134,7 @@ pub const TSQueryError_TSQueryErrorNodeType: TSQueryError = 2; pub const TSQueryError_TSQueryErrorField: TSQueryError = 3; pub const TSQueryError_TSQueryErrorCapture: TSQueryError = 4; pub const TSQueryError_TSQueryErrorStructure: TSQueryError = 5; -pub type TSQueryError = u32; +pub type TSQueryError = ::std::os::raw::c_uint; extern "C" { #[doc = " Create a new parser."] pub fn ts_parser_new() -> *mut TSParser; @@ -148,13 +149,13 @@ extern "C" { #[doc = " Returns a boolean indicating whether or not the language was successfully"] #[doc = " assigned. True means assignment succeeded. False means there was a version"] #[doc = " mismatch: the language was generated with an incompatible version of the"] - #[doc = " Tree-sitter CLI. Check the language\'s version using `ts_language_version`"] - #[doc = " and compare it to this library\'s `TREE_SITTER_LANGUAGE_VERSION` and"] + #[doc = " Tree-sitter CLI. Check the language's version using `ts_language_version`"] + #[doc = " and compare it to this library's `TREE_SITTER_LANGUAGE_VERSION` and"] #[doc = " `TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION` constants."] pub fn ts_parser_set_language(self_: *mut TSParser, language: *const TSLanguage) -> bool; } extern "C" { - #[doc = " Get the parser\'s current language."] + #[doc = " Get the parser's current language."] pub fn ts_parser_language(self_: *const TSParser) -> *const TSLanguage; } extern "C" { @@ -167,7 +168,7 @@ extern "C" { #[doc = ""] #[doc = " The second and third parameters specify the location and length of an array"] #[doc = " of ranges. 
The parser does *not* take ownership of these ranges; it copies"] - #[doc = " the data, so it doesn\'t matter how these ranges are allocated."] + #[doc = " the data, so it doesn't matter how these ranges are allocated."] #[doc = ""] #[doc = " If `length` is zero, then the entire document will be parsed. Otherwise,"] #[doc = " the given ranges must be ordered from earliest to latest in the document,"] @@ -266,7 +267,7 @@ extern "C" { #[doc = ""] #[doc = " If the parser previously failed because of a timeout or a cancellation, then"] #[doc = " by default, it will resume where it left off on the next call to"] - #[doc = " `ts_parser_parse` or other parsing functions. If you don\'t want to resume,"] + #[doc = " `ts_parser_parse` or other parsing functions. If you don't want to resume,"] #[doc = " and instead intend to use this parser to parse some other document, you must"] #[doc = " call `ts_parser_reset` first."] pub fn ts_parser_reset(self_: *mut TSParser); @@ -284,16 +285,16 @@ extern "C" { pub fn ts_parser_timeout_micros(self_: *const TSParser) -> u64; } extern "C" { - #[doc = " Set the parser\'s current cancellation flag pointer."] + #[doc = " Set the parser's current cancellation flag pointer."] #[doc = ""] #[doc = " If a non-null pointer is assigned, then the parser will periodically read"] #[doc = " from this pointer during parsing. If it reads a non-zero value, it will"] #[doc = " halt early, returning NULL. 
See `ts_parser_parse` for more information."] - pub fn ts_parser_set_cancellation_flag(self_: *mut TSParser, flag: *const usize); + pub fn ts_parser_set_cancellation_flag(self_: *mut TSParser, flag: *const size_t); } extern "C" { - #[doc = " Get the parser\'s current cancellation flag pointer."] - pub fn ts_parser_cancellation_flag(self_: *const TSParser) -> *const usize; + #[doc = " Get the parser's current cancellation flag pointer."] + pub fn ts_parser_cancellation_flag(self_: *const TSParser) -> *const size_t; } extern "C" { #[doc = " Set the logger that a parser should use during parsing."] @@ -304,7 +305,7 @@ extern "C" { pub fn ts_parser_set_logger(self_: *mut TSParser, logger: TSLogger); } extern "C" { - #[doc = " Get the parser\'s current logger."] + #[doc = " Get the parser's current logger."] pub fn ts_parser_logger(self_: *const TSParser) -> TSLogger; } extern "C" { @@ -346,7 +347,7 @@ extern "C" { #[doc = " document, returning an array of ranges whose syntactic structure has changed."] #[doc = ""] #[doc = " For this to work correctly, the old syntax tree must have been edited such"] - #[doc = " that its ranges match up to the new tree. Generally, you\'ll want to call"] + #[doc = " that its ranges match up to the new tree. 
Generally, you'll want to call"] #[doc = " this function right after calling one of the `ts_parser_parse` functions."] #[doc = " You need to pass the old tree that was passed to parse, as well as the new"] #[doc = " tree that was returned from that function."] @@ -365,27 +366,27 @@ extern "C" { pub fn ts_tree_print_dot_graph(arg1: *const TSTree, arg2: *mut FILE); } extern "C" { - #[doc = " Get the node\'s type as a null-terminated string."] + #[doc = " Get the node's type as a null-terminated string."] pub fn ts_node_type(arg1: TSNode) -> *const ::std::os::raw::c_char; } extern "C" { - #[doc = " Get the node\'s type as a numerical id."] + #[doc = " Get the node's type as a numerical id."] pub fn ts_node_symbol(arg1: TSNode) -> TSSymbol; } extern "C" { - #[doc = " Get the node\'s start byte."] + #[doc = " Get the node's start byte."] pub fn ts_node_start_byte(arg1: TSNode) -> u32; } extern "C" { - #[doc = " Get the node\'s start position in terms of rows and columns."] + #[doc = " Get the node's start position in terms of rows and columns."] pub fn ts_node_start_point(arg1: TSNode) -> TSPoint; } extern "C" { - #[doc = " Get the node\'s end byte."] + #[doc = " Get the node's end byte."] pub fn ts_node_end_byte(arg1: TSNode) -> u32; } extern "C" { - #[doc = " Get the node\'s end position in terms of rows and columns."] + #[doc = " Get the node's end position in terms of rows and columns."] pub fn ts_node_end_point(arg1: TSNode) -> TSPoint; } extern "C" { @@ -426,11 +427,11 @@ extern "C" { pub fn ts_node_has_error(arg1: TSNode) -> bool; } extern "C" { - #[doc = " Get the node\'s immediate parent."] + #[doc = " Get the node's immediate parent."] pub fn ts_node_parent(arg1: TSNode) -> TSNode; } extern "C" { - #[doc = " Get the node\'s child at the given index, where zero represents the first"] + #[doc = " Get the node's child at the given index, where zero represents the first"] #[doc = " child."] pub fn ts_node_child(arg1: TSNode, arg2: u32) -> TSNode; } @@ -440,23 
+441,23 @@ extern "C" { pub fn ts_node_field_name_for_child(arg1: TSNode, arg2: u32) -> *const ::std::os::raw::c_char; } extern "C" { - #[doc = " Get the node\'s number of children."] + #[doc = " Get the node's number of children."] pub fn ts_node_child_count(arg1: TSNode) -> u32; } extern "C" { - #[doc = " Get the node\'s *named* child at the given index."] + #[doc = " Get the node's *named* child at the given index."] #[doc = ""] #[doc = " See also `ts_node_is_named`."] pub fn ts_node_named_child(arg1: TSNode, arg2: u32) -> TSNode; } extern "C" { - #[doc = " Get the node\'s number of *named* children."] + #[doc = " Get the node's number of *named* children."] #[doc = ""] #[doc = " See also `ts_node_is_named`."] pub fn ts_node_named_child_count(arg1: TSNode) -> u32; } extern "C" { - #[doc = " Get the node\'s child with the given field name."] + #[doc = " Get the node's child with the given field name."] pub fn ts_node_child_by_field_name( self_: TSNode, field_name: *const ::std::os::raw::c_char, @@ -464,32 +465,32 @@ extern "C" { ) -> TSNode; } extern "C" { - #[doc = " Get the node\'s child with the given numerical field id."] + #[doc = " Get the node's child with the given numerical field id."] #[doc = ""] #[doc = " You can convert a field name to an id using the"] #[doc = " `ts_language_field_id_for_name` function."] pub fn ts_node_child_by_field_id(arg1: TSNode, arg2: TSFieldId) -> TSNode; } extern "C" { - #[doc = " Get the node\'s next / previous sibling."] + #[doc = " Get the node's next / previous sibling."] pub fn ts_node_next_sibling(arg1: TSNode) -> TSNode; } extern "C" { pub fn ts_node_prev_sibling(arg1: TSNode) -> TSNode; } extern "C" { - #[doc = " Get the node\'s next / previous *named* sibling."] + #[doc = " Get the node's next / previous *named* sibling."] pub fn ts_node_next_named_sibling(arg1: TSNode) -> TSNode; } extern "C" { pub fn ts_node_prev_named_sibling(arg1: TSNode) -> TSNode; } extern "C" { - #[doc = " Get the node\'s first child that 
extends beyond the given byte offset."] + #[doc = " Get the node's first child that extends beyond the given byte offset."] pub fn ts_node_first_child_for_byte(arg1: TSNode, arg2: u32) -> TSNode; } extern "C" { - #[doc = " Get the node\'s first named child that extends beyond the given byte offset."] + #[doc = " Get the node's first named child that extends beyond the given byte offset."] pub fn ts_node_first_named_child_for_byte(arg1: TSNode, arg2: u32) -> TSNode; } extern "C" { @@ -544,22 +545,22 @@ extern "C" { pub fn ts_tree_cursor_reset(arg1: *mut TSTreeCursor, arg2: TSNode); } extern "C" { - #[doc = " Get the tree cursor\'s current node."] + #[doc = " Get the tree cursor's current node."] pub fn ts_tree_cursor_current_node(arg1: *const TSTreeCursor) -> TSNode; } extern "C" { - #[doc = " Get the field name of the tree cursor\'s current node."] + #[doc = " Get the field name of the tree cursor's current node."] #[doc = ""] - #[doc = " This returns `NULL` if the current node doesn\'t have a field."] + #[doc = " This returns `NULL` if the current node doesn't have a field."] #[doc = " See also `ts_node_child_by_field_name`."] pub fn ts_tree_cursor_current_field_name( arg1: *const TSTreeCursor, ) -> *const ::std::os::raw::c_char; } extern "C" { - #[doc = " Get the field name of the tree cursor\'s current node."] + #[doc = " Get the field name of the tree cursor's current node."] #[doc = ""] - #[doc = " This returns zero if the current node doesn\'t have a field."] + #[doc = " This returns zero if the current node doesn't have a field."] #[doc = " See also `ts_node_child_by_field_id`, `ts_language_field_id_for_name`."] pub fn ts_tree_cursor_current_field_id(arg1: *const TSTreeCursor) -> TSFieldId; } @@ -628,7 +629,7 @@ extern "C" { pub fn ts_query_string_count(arg1: *const TSQuery) -> u32; } extern "C" { - #[doc = " Get the byte offset where the given pattern starts in the query\'s source."] + #[doc = " Get the byte offset where the given pattern starts in the 
query's source."] #[doc = ""] #[doc = " This can be useful when combining queries by concatenating their source"] #[doc = " code strings."] @@ -659,9 +660,9 @@ extern "C" { pub fn ts_query_step_is_definite(self_: *const TSQuery, byte_offset: u32) -> bool; } extern "C" { - #[doc = " Get the name and length of one of the query\'s captures, or one of the"] - #[doc = " query\'s string literals. Each capture and string is associated with a"] - #[doc = " numeric id based on the order that it appeared in the query\'s source."] + #[doc = " Get the name and length of one of the query's captures, or one of the"] + #[doc = " query's string literals. Each capture and string is associated with a"] + #[doc = " numeric id based on the order that it appeared in the query's source."] pub fn ts_query_capture_name_for_id( arg1: *const TSQuery, id: u32, @@ -708,10 +709,10 @@ extern "C" { #[doc = " captures that appear *before* some of the captures from a previous match."] #[doc = " 2. Repeatedly call `ts_query_cursor_next_capture` to iterate over all of the"] #[doc = " individual *captures* in the order that they appear. 
This is useful if"] - #[doc = " don\'t care about which pattern matched, and just want a single ordered"] + #[doc = " don't care about which pattern matched, and just want a single ordered"] #[doc = " sequence of captures."] #[doc = ""] - #[doc = " If you don\'t care about consuming all of the results, you can stop calling"] + #[doc = " If you don't care about consuming all of the results, you can stop calling"] #[doc = " `ts_query_cursor_next_match` or `ts_query_cursor_next_capture` at any point."] #[doc = " You can then start executing another query on another node by calling"] #[doc = " `ts_query_cursor_exec` again."] @@ -736,8 +737,18 @@ extern "C" { pub fn ts_query_cursor_did_exceed_match_limit(arg1: *const TSQueryCursor) -> bool; } extern "C" { - #[doc = " Set the range of bytes or (row, column) positions in which the query"] + #[doc = " Get or set the range of bytes or (row, column) positions in which the query"] #[doc = " will be executed."] + pub fn ts_query_cursor_byte_range(arg1: *const TSQueryCursor, arg2: *mut u32, arg3: *mut u32); +} +extern "C" { + pub fn ts_query_cursor_point_range( + arg1: *const TSQueryCursor, + arg2: *mut TSPoint, + arg3: *mut TSPoint, + ); +} +extern "C" { pub fn ts_query_cursor_set_byte_range(arg1: *mut TSQueryCursor, arg2: u32, arg3: u32); } extern "C" { @@ -757,7 +768,7 @@ extern "C" { #[doc = " Advance to the next capture of the currently running query."] #[doc = ""] #[doc = " If there is a capture, write its match to `*match` and its index within"] - #[doc = " the matche\'s capture list to `*capture_index`. Otherwise, return `false`."] + #[doc = " the matche's capture list to `*capture_index`. 
Otherwise, return `false`."]
     pub fn ts_query_cursor_next_capture(
         arg1: *mut TSQueryCursor,
         match_: *mut TSQueryMatch,
diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs
index 92e3c9d6..ea99c067 100644
--- a/lib/binding_rust/lib.rs
+++ b/lib/binding_rust/lib.rs
@@ -1809,6 +1809,44 @@ impl<'a, 'tree, T: TextProvider<'a>> Iterator for QueryMatches<'a, 'tree, T> {
     }
 }
 
+impl<'a, 'tree, T: TextProvider<'a>> QueryCaptures<'a, 'tree, T> {
+    /// Set the start of the cursor's byte range to `offset`, keeping the
+    /// current end of the range.
+    ///
+    /// The underlying C API takes `u32` offsets, so an `offset` larger than
+    /// `u32::MAX` is clamped to `u32::MAX` instead of wrapping around.
+    pub fn advance_to_byte(&mut self, offset: usize) {
+        // A plain `as u32` cast would wrap for offsets >= 2^32 and could move
+        // the start of the range backwards; saturate instead.
+        let start_byte = offset.min(u32::MAX as usize) as u32;
+        let mut current_start = 0u32;
+        let mut current_end = 0u32;
+        // SAFETY: both out-pointers refer to live locals of the right type.
+        // NOTE(review): assumes `self.ptr` is a valid cursor pointer for as
+        // long as this iterator exists -- confirm against the struct definition.
+        unsafe {
+            // Only the current end is needed; the start is about to be replaced.
+            ffi::ts_query_cursor_byte_range(self.ptr, &mut current_start, &mut current_end);
+            ffi::ts_query_cursor_set_byte_range(self.ptr, start_byte, current_end);
+        }
+    }
+
+    /// Set the start of the cursor's point range to `point`, keeping the
+    /// current end of the range.
+    pub fn advance_to_point(&mut self, point: Point) {
+        let mut current_start = ffi::TSPoint { row: 0, column: 0 };
+        let mut current_end = current_start;
+        // SAFETY: both out-pointers refer to live locals of the right type.
+        // NOTE(review): assumes `self.ptr` is a valid cursor pointer for as
+        // long as this iterator exists -- confirm against the struct definition.
+        unsafe {
+            // Only the current end is needed; the start is about to be replaced.
+            ffi::ts_query_cursor_point_range(self.ptr, &mut current_start, &mut current_end);
+            ffi::ts_query_cursor_set_point_range(self.ptr, point.into(), current_end);
+        }
+    }
+}
+
 impl<'a, 'tree, T: TextProvider<'a>> Iterator for QueryCaptures<'a, 'tree, T> {
     type Item = (QueryMatch<'a, 'tree>, usize);
 
diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h
index 43315415..01d84a6e 100644
--- a/lib/include/tree_sitter/api.h
+++ b/lib/include/tree_sitter/api.h
@@ -809,9 +809,11 @@ void ts_query_cursor_exec(TSQueryCursor *, const TSQuery *, TSNode);
 bool ts_query_cursor_did_exceed_match_limit(const TSQueryCursor *);
 
 /**
- * Set the range of bytes or (row, column) positions in which the query
+ * Get or set the range of bytes or (row, column) positions in which the query
  * will be executed.
*/ +void ts_query_cursor_byte_range(const TSQueryCursor *, uint32_t *, uint32_t *); +void ts_query_cursor_point_range(const TSQueryCursor *, TSPoint *, TSPoint *); void ts_query_cursor_set_byte_range(TSQueryCursor *, uint32_t, uint32_t); void ts_query_cursor_set_point_range(TSQueryCursor *, TSPoint, TSPoint); diff --git a/lib/src/query.c b/lib/src/query.c index 65dbe1fe..278b3a3c 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -2302,6 +2302,24 @@ void ts_query_cursor_exec( self->did_exceed_match_limit = false; } +void ts_query_cursor_byte_range( + const TSQueryCursor *self, + uint32_t *start_byte, + uint32_t *end_byte +) { + *start_byte = self->start_byte; + *end_byte = self->end_byte; +} + +void ts_query_cursor_point_range( + const TSQueryCursor *self, + TSPoint *start_point, + TSPoint *end_point +) { + *start_point = self->start_point; + *end_point = self->end_point; +} + void ts_query_cursor_set_byte_range( TSQueryCursor *self, uint32_t start_byte, @@ -2639,7 +2657,7 @@ static inline bool ts_query_cursor__advance( } else if (ts_tree_cursor_goto_parent(&self->cursor)) { self->depth--; } else { - LOG("halt at root"); + LOG("halt at root\n"); self->halted = true; } @@ -2696,6 +2714,7 @@ static inline bool ts_query_cursor__advance( if (!ts_tree_cursor_goto_next_sibling(&self->cursor)) { self->ascending = true; } + LOG("skip until start of range\n"); continue; } @@ -2704,7 +2723,7 @@ static inline bool ts_query_cursor__advance( self->end_byte <= ts_node_start_byte(node) || point_lte(self->end_point, ts_node_start_point(node)) ) { - LOG("halt at end of range"); + LOG("halt at end of range\n"); self->halted = true; continue; }