Add APIs for advancing a QueryCursor to an arbitrary position

This commit is contained in:
Max Brunsfeld 2021-05-24 21:07:59 -07:00
parent 8c3d1466ec
commit a61f25bc58
5 changed files with 178 additions and 50 deletions

View file

@ -3036,6 +3036,74 @@ fn test_query_text_callback_returns_chunks() {
});
}
#[test]
fn test_query_captures_advance_to_byte() {
    // Exercises `QueryCaptures::advance_to_byte`: captures are consumed
    // partially, the iterator is advanced further into the source, and the
    // captures that follow are checked.
    allocations::record(|| {
        let language = get_language("javascript");
        let query = Query::new(
            language,
            r#"
(identifier) @id
(array
"[" @lbracket
"]" @rbracket)
"#,
        )
        .unwrap();
        let source = "[one, two, [three, four, five, six, seven, eight, nine, ten], eleven, twelve, thirteen]";

        let mut parser = Parser::new();
        parser.set_language(language).unwrap();
        let tree = parser.parse(&source, None).unwrap();

        // Limit the query to the span running from one byte into "two" up to
        // the comma that precedes "twelve".
        let mut cursor = QueryCursor::new();
        let range_start = source.find("two").unwrap() + 1;
        let range_end = source.find(", twelve").unwrap();
        cursor.set_byte_range(range_start, range_end);
        let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes());

        // Pull only the first four captures, leaving the iterator usable.
        let first_four: Vec<(&str, &str)> = captures
            .by_ref()
            .take(4)
            .map(|(mat, capture_ix)| {
                let capture = mat.captures[capture_ix as usize];
                let name = query.capture_names()[capture.index as usize].as_str();
                (name, &source[capture.node.byte_range()])
            })
            .collect();
        assert_eq!(
            first_four,
            vec![
                ("id", "two"),
                ("lbracket", "["),
                ("id", "three"),
                ("id", "four")
            ]
        );

        // Jump ahead to one byte into "ten", then drain whatever is left.
        captures.advance_to_byte(source.find("ten").unwrap() + 1);
        let remaining: Vec<(&str, &str)> = captures
            .map(|(mat, capture_ix)| {
                let capture = mat.captures[capture_ix as usize];
                let name = query.capture_names()[capture.index as usize].as_str();
                (name, &source[capture.node.byte_range()])
            })
            .collect();
        assert_eq!(
            remaining,
            vec![("id", "ten"), ("rbracket", "]"), ("id", "eleven"),]
        );

        // Start over and advance to the very end of the source: the iterator
        // must be empty, and must stay empty when polled again.
        let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes());
        captures.advance_to_byte(source.len());
        assert!(captures.next().is_none());
        assert!(captures.next().is_none());
    });
}
#[test]
fn test_query_start_byte_for_pattern() {
let language = get_language("javascript");

View file

@ -1,6 +1,7 @@
/* automatically generated by rust-bindgen */
/* automatically generated by rust-bindgen 0.58.1 */
pub type __darwin_size_t = ::std::os::raw::c_ulong;
pub type size_t = usize;
pub type FILE = [u64; 19usize];
pub type TSSymbol = u16;
pub type TSFieldId = u16;
@ -31,11 +32,11 @@ pub struct TSQueryCursor {
}
pub const TSInputEncoding_TSInputEncodingUTF8: TSInputEncoding = 0;
pub const TSInputEncoding_TSInputEncodingUTF16: TSInputEncoding = 1;
pub type TSInputEncoding = u32;
pub type TSInputEncoding = ::std::os::raw::c_uint;
pub const TSSymbolType_TSSymbolTypeRegular: TSSymbolType = 0;
pub const TSSymbolType_TSSymbolTypeAnonymous: TSSymbolType = 1;
pub const TSSymbolType_TSSymbolTypeAuxiliary: TSSymbolType = 2;
pub type TSSymbolType = u32;
pub type TSSymbolType = ::std::os::raw::c_uint;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct TSPoint {
@ -66,7 +67,7 @@ pub struct TSInput {
}
pub const TSLogType_TSLogTypeParse: TSLogType = 0;
pub const TSLogType_TSLogTypeLex: TSLogType = 1;
pub type TSLogType = u32;
pub type TSLogType = ::std::os::raw::c_uint;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct TSLogger {
@ -120,7 +121,7 @@ pub struct TSQueryMatch {
pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeDone: TSQueryPredicateStepType = 0;
pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeCapture: TSQueryPredicateStepType = 1;
pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeString: TSQueryPredicateStepType = 2;
pub type TSQueryPredicateStepType = u32;
pub type TSQueryPredicateStepType = ::std::os::raw::c_uint;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct TSQueryPredicateStep {
@ -133,7 +134,7 @@ pub const TSQueryError_TSQueryErrorNodeType: TSQueryError = 2;
pub const TSQueryError_TSQueryErrorField: TSQueryError = 3;
pub const TSQueryError_TSQueryErrorCapture: TSQueryError = 4;
pub const TSQueryError_TSQueryErrorStructure: TSQueryError = 5;
pub type TSQueryError = u32;
pub type TSQueryError = ::std::os::raw::c_uint;
extern "C" {
#[doc = " Create a new parser."]
pub fn ts_parser_new() -> *mut TSParser;
@ -148,13 +149,13 @@ extern "C" {
#[doc = " Returns a boolean indicating whether or not the language was successfully"]
#[doc = " assigned. True means assignment succeeded. False means there was a version"]
#[doc = " mismatch: the language was generated with an incompatible version of the"]
#[doc = " Tree-sitter CLI. Check the language\'s version using `ts_language_version`"]
#[doc = " and compare it to this library\'s `TREE_SITTER_LANGUAGE_VERSION` and"]
#[doc = " Tree-sitter CLI. Check the language's version using `ts_language_version`"]
#[doc = " and compare it to this library's `TREE_SITTER_LANGUAGE_VERSION` and"]
#[doc = " `TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION` constants."]
pub fn ts_parser_set_language(self_: *mut TSParser, language: *const TSLanguage) -> bool;
}
extern "C" {
#[doc = " Get the parser\'s current language."]
#[doc = " Get the parser's current language."]
pub fn ts_parser_language(self_: *const TSParser) -> *const TSLanguage;
}
extern "C" {
@ -167,7 +168,7 @@ extern "C" {
#[doc = ""]
#[doc = " The second and third parameters specify the location and length of an array"]
#[doc = " of ranges. The parser does *not* take ownership of these ranges; it copies"]
#[doc = " the data, so it doesn\'t matter how these ranges are allocated."]
#[doc = " the data, so it doesn't matter how these ranges are allocated."]
#[doc = ""]
#[doc = " If `length` is zero, then the entire document will be parsed. Otherwise,"]
#[doc = " the given ranges must be ordered from earliest to latest in the document,"]
@ -266,7 +267,7 @@ extern "C" {
#[doc = ""]
#[doc = " If the parser previously failed because of a timeout or a cancellation, then"]
#[doc = " by default, it will resume where it left off on the next call to"]
#[doc = " `ts_parser_parse` or other parsing functions. If you don\'t want to resume,"]
#[doc = " `ts_parser_parse` or other parsing functions. If you don't want to resume,"]
#[doc = " and instead intend to use this parser to parse some other document, you must"]
#[doc = " call `ts_parser_reset` first."]
pub fn ts_parser_reset(self_: *mut TSParser);
@ -284,16 +285,16 @@ extern "C" {
pub fn ts_parser_timeout_micros(self_: *const TSParser) -> u64;
}
extern "C" {
#[doc = " Set the parser\'s current cancellation flag pointer."]
#[doc = " Set the parser's current cancellation flag pointer."]
#[doc = ""]
#[doc = " If a non-null pointer is assigned, then the parser will periodically read"]
#[doc = " from this pointer during parsing. If it reads a non-zero value, it will"]
#[doc = " halt early, returning NULL. See `ts_parser_parse` for more information."]
pub fn ts_parser_set_cancellation_flag(self_: *mut TSParser, flag: *const usize);
pub fn ts_parser_set_cancellation_flag(self_: *mut TSParser, flag: *const size_t);
}
extern "C" {
#[doc = " Get the parser\'s current cancellation flag pointer."]
pub fn ts_parser_cancellation_flag(self_: *const TSParser) -> *const usize;
#[doc = " Get the parser's current cancellation flag pointer."]
pub fn ts_parser_cancellation_flag(self_: *const TSParser) -> *const size_t;
}
extern "C" {
#[doc = " Set the logger that a parser should use during parsing."]
@ -304,7 +305,7 @@ extern "C" {
pub fn ts_parser_set_logger(self_: *mut TSParser, logger: TSLogger);
}
extern "C" {
#[doc = " Get the parser\'s current logger."]
#[doc = " Get the parser's current logger."]
pub fn ts_parser_logger(self_: *const TSParser) -> TSLogger;
}
extern "C" {
@ -346,7 +347,7 @@ extern "C" {
#[doc = " document, returning an array of ranges whose syntactic structure has changed."]
#[doc = ""]
#[doc = " For this to work correctly, the old syntax tree must have been edited such"]
#[doc = " that its ranges match up to the new tree. Generally, you\'ll want to call"]
#[doc = " that its ranges match up to the new tree. Generally, you'll want to call"]
#[doc = " this function right after calling one of the `ts_parser_parse` functions."]
#[doc = " You need to pass the old tree that was passed to parse, as well as the new"]
#[doc = " tree that was returned from that function."]
@ -365,27 +366,27 @@ extern "C" {
pub fn ts_tree_print_dot_graph(arg1: *const TSTree, arg2: *mut FILE);
}
extern "C" {
#[doc = " Get the node\'s type as a null-terminated string."]
#[doc = " Get the node's type as a null-terminated string."]
pub fn ts_node_type(arg1: TSNode) -> *const ::std::os::raw::c_char;
}
extern "C" {
#[doc = " Get the node\'s type as a numerical id."]
#[doc = " Get the node's type as a numerical id."]
pub fn ts_node_symbol(arg1: TSNode) -> TSSymbol;
}
extern "C" {
#[doc = " Get the node\'s start byte."]
#[doc = " Get the node's start byte."]
pub fn ts_node_start_byte(arg1: TSNode) -> u32;
}
extern "C" {
#[doc = " Get the node\'s start position in terms of rows and columns."]
#[doc = " Get the node's start position in terms of rows and columns."]
pub fn ts_node_start_point(arg1: TSNode) -> TSPoint;
}
extern "C" {
#[doc = " Get the node\'s end byte."]
#[doc = " Get the node's end byte."]
pub fn ts_node_end_byte(arg1: TSNode) -> u32;
}
extern "C" {
#[doc = " Get the node\'s end position in terms of rows and columns."]
#[doc = " Get the node's end position in terms of rows and columns."]
pub fn ts_node_end_point(arg1: TSNode) -> TSPoint;
}
extern "C" {
@ -426,11 +427,11 @@ extern "C" {
pub fn ts_node_has_error(arg1: TSNode) -> bool;
}
extern "C" {
#[doc = " Get the node\'s immediate parent."]
#[doc = " Get the node's immediate parent."]
pub fn ts_node_parent(arg1: TSNode) -> TSNode;
}
extern "C" {
#[doc = " Get the node\'s child at the given index, where zero represents the first"]
#[doc = " Get the node's child at the given index, where zero represents the first"]
#[doc = " child."]
pub fn ts_node_child(arg1: TSNode, arg2: u32) -> TSNode;
}
@ -440,23 +441,23 @@ extern "C" {
pub fn ts_node_field_name_for_child(arg1: TSNode, arg2: u32) -> *const ::std::os::raw::c_char;
}
extern "C" {
#[doc = " Get the node\'s number of children."]
#[doc = " Get the node's number of children."]
pub fn ts_node_child_count(arg1: TSNode) -> u32;
}
extern "C" {
#[doc = " Get the node\'s *named* child at the given index."]
#[doc = " Get the node's *named* child at the given index."]
#[doc = ""]
#[doc = " See also `ts_node_is_named`."]
pub fn ts_node_named_child(arg1: TSNode, arg2: u32) -> TSNode;
}
extern "C" {
#[doc = " Get the node\'s number of *named* children."]
#[doc = " Get the node's number of *named* children."]
#[doc = ""]
#[doc = " See also `ts_node_is_named`."]
pub fn ts_node_named_child_count(arg1: TSNode) -> u32;
}
extern "C" {
#[doc = " Get the node\'s child with the given field name."]
#[doc = " Get the node's child with the given field name."]
pub fn ts_node_child_by_field_name(
self_: TSNode,
field_name: *const ::std::os::raw::c_char,
@ -464,32 +465,32 @@ extern "C" {
) -> TSNode;
}
extern "C" {
#[doc = " Get the node\'s child with the given numerical field id."]
#[doc = " Get the node's child with the given numerical field id."]
#[doc = ""]
#[doc = " You can convert a field name to an id using the"]
#[doc = " `ts_language_field_id_for_name` function."]
pub fn ts_node_child_by_field_id(arg1: TSNode, arg2: TSFieldId) -> TSNode;
}
extern "C" {
#[doc = " Get the node\'s next / previous sibling."]
#[doc = " Get the node's next / previous sibling."]
pub fn ts_node_next_sibling(arg1: TSNode) -> TSNode;
}
extern "C" {
pub fn ts_node_prev_sibling(arg1: TSNode) -> TSNode;
}
extern "C" {
#[doc = " Get the node\'s next / previous *named* sibling."]
#[doc = " Get the node's next / previous *named* sibling."]
pub fn ts_node_next_named_sibling(arg1: TSNode) -> TSNode;
}
extern "C" {
pub fn ts_node_prev_named_sibling(arg1: TSNode) -> TSNode;
}
extern "C" {
#[doc = " Get the node\'s first child that extends beyond the given byte offset."]
#[doc = " Get the node's first child that extends beyond the given byte offset."]
pub fn ts_node_first_child_for_byte(arg1: TSNode, arg2: u32) -> TSNode;
}
extern "C" {
#[doc = " Get the node\'s first named child that extends beyond the given byte offset."]
#[doc = " Get the node's first named child that extends beyond the given byte offset."]
pub fn ts_node_first_named_child_for_byte(arg1: TSNode, arg2: u32) -> TSNode;
}
extern "C" {
@ -544,22 +545,22 @@ extern "C" {
pub fn ts_tree_cursor_reset(arg1: *mut TSTreeCursor, arg2: TSNode);
}
extern "C" {
#[doc = " Get the tree cursor\'s current node."]
#[doc = " Get the tree cursor's current node."]
pub fn ts_tree_cursor_current_node(arg1: *const TSTreeCursor) -> TSNode;
}
extern "C" {
#[doc = " Get the field name of the tree cursor\'s current node."]
#[doc = " Get the field name of the tree cursor's current node."]
#[doc = ""]
#[doc = " This returns `NULL` if the current node doesn\'t have a field."]
#[doc = " This returns `NULL` if the current node doesn't have a field."]
#[doc = " See also `ts_node_child_by_field_name`."]
pub fn ts_tree_cursor_current_field_name(
arg1: *const TSTreeCursor,
) -> *const ::std::os::raw::c_char;
}
extern "C" {
#[doc = " Get the field name of the tree cursor\'s current node."]
#[doc = " Get the field name of the tree cursor's current node."]
#[doc = ""]
#[doc = " This returns zero if the current node doesn\'t have a field."]
#[doc = " This returns zero if the current node doesn't have a field."]
#[doc = " See also `ts_node_child_by_field_id`, `ts_language_field_id_for_name`."]
pub fn ts_tree_cursor_current_field_id(arg1: *const TSTreeCursor) -> TSFieldId;
}
@ -628,7 +629,7 @@ extern "C" {
pub fn ts_query_string_count(arg1: *const TSQuery) -> u32;
}
extern "C" {
#[doc = " Get the byte offset where the given pattern starts in the query\'s source."]
#[doc = " Get the byte offset where the given pattern starts in the query's source."]
#[doc = ""]
#[doc = " This can be useful when combining queries by concatenating their source"]
#[doc = " code strings."]
@ -659,9 +660,9 @@ extern "C" {
pub fn ts_query_step_is_definite(self_: *const TSQuery, byte_offset: u32) -> bool;
}
extern "C" {
#[doc = " Get the name and length of one of the query\'s captures, or one of the"]
#[doc = " query\'s string literals. Each capture and string is associated with a"]
#[doc = " numeric id based on the order that it appeared in the query\'s source."]
#[doc = " Get the name and length of one of the query's captures, or one of the"]
#[doc = " query's string literals. Each capture and string is associated with a"]
#[doc = " numeric id based on the order that it appeared in the query's source."]
pub fn ts_query_capture_name_for_id(
arg1: *const TSQuery,
id: u32,
@ -708,10 +709,10 @@ extern "C" {
#[doc = " captures that appear *before* some of the captures from a previous match."]
#[doc = " 2. Repeatedly call `ts_query_cursor_next_capture` to iterate over all of the"]
#[doc = " individual *captures* in the order that they appear. This is useful if"]
#[doc = " don\'t care about which pattern matched, and just want a single ordered"]
#[doc = " don't care about which pattern matched, and just want a single ordered"]
#[doc = " sequence of captures."]
#[doc = ""]
#[doc = " If you don\'t care about consuming all of the results, you can stop calling"]
#[doc = " If you don't care about consuming all of the results, you can stop calling"]
#[doc = " `ts_query_cursor_next_match` or `ts_query_cursor_next_capture` at any point."]
#[doc = " You can then start executing another query on another node by calling"]
#[doc = " `ts_query_cursor_exec` again."]
@ -736,8 +737,18 @@ extern "C" {
pub fn ts_query_cursor_did_exceed_match_limit(arg1: *const TSQueryCursor) -> bool;
}
extern "C" {
#[doc = " Set the range of bytes or (row, column) positions in which the query"]
#[doc = " Get or set the range of bytes or (row, column) positions in which the query"]
#[doc = " will be executed."]
pub fn ts_query_cursor_byte_range(arg1: *const TSQueryCursor, arg2: *mut u32, arg3: *mut u32);
}
extern "C" {
pub fn ts_query_cursor_point_range(
arg1: *const TSQueryCursor,
arg2: *mut TSPoint,
arg3: *mut TSPoint,
);
}
extern "C" {
pub fn ts_query_cursor_set_byte_range(arg1: *mut TSQueryCursor, arg2: u32, arg3: u32);
}
extern "C" {
@ -757,7 +768,7 @@ extern "C" {
#[doc = " Advance to the next capture of the currently running query."]
#[doc = ""]
#[doc = " If there is a capture, write its match to `*match` and its index within"]
#[doc = " the matche\'s capture list to `*capture_index`. Otherwise, return `false`."]
#[doc = " the matche's capture list to `*capture_index`. Otherwise, return `false`."]
pub fn ts_query_cursor_next_capture(
arg1: *mut TSQueryCursor,
match_: *mut TSQueryMatch,

View file

@ -1809,6 +1809,34 @@ impl<'a, 'tree, T: TextProvider<'a>> Iterator for QueryMatches<'a, 'tree, T> {
}
}
impl<'a, 'tree, T: TextProvider<'a>> QueryCaptures<'a, 'tree, T> {
    /// Move the start of the underlying cursor's byte range to `offset`,
    /// leaving the end of the range as it currently is.
    ///
    /// The C API works with 32-bit byte offsets. An `offset` larger than
    /// `u32::MAX` is clamped to `u32::MAX` rather than being truncated, so an
    /// oversized value can never wrap around to a position near the start of
    /// the document.
    pub fn advance_to_byte(&mut self, offset: usize) {
        // Clamp instead of using a bare `as u32`, which would silently wrap.
        let new_start = if offset > u32::MAX as usize {
            u32::MAX
        } else {
            offset as u32
        };

        // Only the end of the current range is reused; the start is fetched
        // because the C getter always writes both values.
        let mut current_start = 0u32;
        let mut current_end = 0u32;

        // SAFETY: both out-pointers refer to live locals of the right type.
        // NOTE(review): assumes `self.ptr` is the valid cursor this iterator
        // was created from -- confirm against the constructor.
        unsafe {
            ffi::ts_query_cursor_byte_range(
                self.ptr,
                &mut current_start as *mut u32,
                &mut current_end as *mut u32,
            );
            ffi::ts_query_cursor_set_byte_range(self.ptr, new_start, current_end);
        }
    }

    /// Move the start of the underlying cursor's (row, column) range to
    /// `point`, leaving the end of the range as it currently is.
    pub fn advance_to_point(&mut self, point: Point) {
        // Placeholder values; the C getter overwrites both of them.
        let mut current_start = ffi::TSPoint { row: 0, column: 0 };
        let mut current_end = current_start;

        // SAFETY: both out-pointers refer to live locals of the right type.
        // NOTE(review): assumes `self.ptr` is the valid cursor this iterator
        // was created from -- confirm against the constructor.
        unsafe {
            ffi::ts_query_cursor_point_range(
                self.ptr,
                &mut current_start as *mut _,
                &mut current_end as *mut _,
            );
            ffi::ts_query_cursor_set_point_range(self.ptr, point.into(), current_end);
        }
    }
}
impl<'a, 'tree, T: TextProvider<'a>> Iterator for QueryCaptures<'a, 'tree, T> {
type Item = (QueryMatch<'a, 'tree>, usize);

View file

@ -809,9 +809,11 @@ void ts_query_cursor_exec(TSQueryCursor *, const TSQuery *, TSNode);
bool ts_query_cursor_did_exceed_match_limit(const TSQueryCursor *);
/**
* Set the range of bytes or (row, column) positions in which the query
* Get or set the range of bytes or (row, column) positions in which the query
* will be executed.
*/
void ts_query_cursor_byte_range(const TSQueryCursor *, uint32_t *, uint32_t *);
void ts_query_cursor_point_range(const TSQueryCursor *, TSPoint *, TSPoint *);
void ts_query_cursor_set_byte_range(TSQueryCursor *, uint32_t, uint32_t);
void ts_query_cursor_set_point_range(TSQueryCursor *, TSPoint, TSPoint);

View file

@ -2302,6 +2302,24 @@ void ts_query_cursor_exec(
self->did_exceed_match_limit = false;
}
// Copy the cursor's stored byte range into the two caller-provided
// out-parameters: `*start_byte` receives the start, `*end_byte` the end.
void ts_query_cursor_byte_range(const TSQueryCursor *self, uint32_t *start_byte, uint32_t *end_byte) {
  *start_byte = self->start_byte;
  *end_byte = self->end_byte;
}
// Copy the cursor's stored (row, column) range into the two caller-provided
// out-parameters: `*start_point` receives the start, `*end_point` the end.
void ts_query_cursor_point_range(const TSQueryCursor *self, TSPoint *start_point, TSPoint *end_point) {
  *start_point = self->start_point;
  *end_point = self->end_point;
}
void ts_query_cursor_set_byte_range(
TSQueryCursor *self,
uint32_t start_byte,
@ -2639,7 +2657,7 @@ static inline bool ts_query_cursor__advance(
} else if (ts_tree_cursor_goto_parent(&self->cursor)) {
self->depth--;
} else {
LOG("halt at root");
LOG("halt at root\n");
self->halted = true;
}
@ -2696,6 +2714,7 @@ static inline bool ts_query_cursor__advance(
if (!ts_tree_cursor_goto_next_sibling(&self->cursor)) {
self->ascending = true;
}
LOG("skip until start of range\n");
continue;
}
@ -2704,7 +2723,7 @@ static inline bool ts_query_cursor__advance(
self->end_byte <= ts_node_start_byte(node) ||
point_lte(self->end_point, ts_node_start_point(node))
) {
LOG("halt at end of range");
LOG("halt at end of range\n");
self->halted = true;
continue;
}