From 6b1ebd3d29b44890105d35e9d181dbdfbfa69a8c Mon Sep 17 00:00:00 2001 From: Lukas Seidel Date: Sun, 29 Sep 2024 23:34:48 +0200 Subject: [PATCH] feat!: implement `StreamingIterator` instead of `Iterator` for `QueryMatches` and `QueryCaptures` This fixes UB when either `QueryMatches` or `QueryCaptures` had collect called on it. Co-authored-by: Amaan Qureshi --- Cargo.lock | 10 ++ Cargo.toml | 1 + cli/Cargo.toml | 1 + cli/src/query.rs | 13 ++- cli/src/tests/helpers/query_helpers.rs | 41 ++++---- cli/src/tests/query_test.rs | 131 ++++++++++++++++--------- cli/src/tests/text_provider_test.rs | 5 +- cli/src/tests/wasm_language_test.rs | 1 + highlight/Cargo.toml | 1 + highlight/src/lib.rs | 99 +++++++++++++++++-- lib/Cargo.toml | 1 + lib/binding_rust/lib.rs | 66 +++++++++---- tags/Cargo.toml | 1 + tags/src/lib.rs | 5 +- 14 files changed, 271 insertions(+), 105 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 242ab31e..3b8ce295 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1297,6 +1297,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520" + [[package]] name = "strsim" version = "0.11.1" @@ -1462,6 +1468,7 @@ dependencies = [ "cc", "regex", "regex-syntax", + "streaming-iterator", "tree-sitter-language", "wasmtime-c-api-impl", ] @@ -1498,6 +1505,7 @@ dependencies = [ "serde_json", "similar", "smallbitvec", + "streaming-iterator", "tempfile", "tiny_http", "tree-sitter", @@ -1550,6 +1558,7 @@ version = "0.23.0" dependencies = [ "lazy_static", "regex", + "streaming-iterator", "thiserror", "tree-sitter", ] @@ -1585,6 +1594,7 @@ version = "0.23.0" dependencies = [ "memchr", "regex", + "streaming-iterator", "thiserror", "tree-sitter", ] diff --git a/Cargo.toml b/Cargo.toml index 89c83465..999de9ec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -84,6 +84,7 @@ serde_derive = "1.0.210" serde_json = { version = "1.0.128", features = ["preserve_order"] } similar = "2.6.0" smallbitvec = "2.5.3" +streaming-iterator = "0.1.9" tempfile = "3.12.0" thiserror = "1.0.64" tiny_http = "0.12.0" diff --git a/cli/Cargo.toml b/cli/Cargo.toml index 869593fa..8e7e1916 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -52,6 +52,7 @@ serde_derive.workspace = true serde_json.workspace = true similar.workspace = true smallbitvec.workspace = true +streaming-iterator.workspace = true tiny_http.workspace = true walkdir.workspace = true wasmparser.workspace = true diff --git a/cli/src/query.rs b/cli/src/query.rs index 085fb967..2d8a1013 100644 --- a/cli/src/query.rs +++ b/cli/src/query.rs @@ -8,6 +8,7 @@ use std::{ use anstyle::AnsiColor; use anyhow::{Context, Result}; +use streaming_iterator::StreamingIterator; use tree_sitter::{Language, Parser, Point, Query, QueryCursor}; use crate::{ @@ -58,10 +59,10 @@ pub fn query_files_at_paths( let start = Instant::now(); if ordered_captures { - for (mat, capture_index) in - query_cursor.captures(&query, tree.root_node(), source_code.as_slice()) - { - let capture = mat.captures[capture_index]; + let mut captures = + query_cursor.captures(&query, tree.root_node(), source_code.as_slice()); + while let Some((mat, capture_index)) = captures.next() { + let capture = mat.captures[*capture_index]; let capture_name = &query.capture_names()[capture.index as usize]; if !quiet && !should_test { writeln!( @@ -81,7 +82,9 @@ pub fn query_files_at_paths( }); } } else { - for m in query_cursor.matches(&query, tree.root_node(), source_code.as_slice()) { + let mut matches = + query_cursor.matches(&query, tree.root_node(), source_code.as_slice()); + while let Some(m) = matches.next() { if !quiet && !should_test { writeln!(&mut stdout, " pattern: {}", m.pattern_index)?; } diff --git a/cli/src/tests/helpers/query_helpers.rs b/cli/src/tests/helpers/query_helpers.rs index 9e7a6f63..da6e4769 100644 --- a/cli/src/tests/helpers/query_helpers.rs +++ b/cli/src/tests/helpers/query_helpers.rs @@ -1,6 +1,7 @@ use std::{cmp::Ordering, fmt::Write, ops::Range}; use rand::prelude::Rng; +use streaming_iterator::{IntoStreamingIterator, StreamingIterator}; use tree_sitter::{ Language, Node, Parser, Point, Query, QueryCapture, QueryCursor, QueryMatch, Tree, TreeCursor, }; @@ -324,39 +325,39 @@ pub fn assert_query_matches( } pub fn collect_matches<'a>( - matches: impl Iterator>, + mut matches: impl StreamingIterator>, query: &'a Query, source: &'a str, ) -> Vec<(usize, Vec<(&'a str, &'a str)>)> { - matches - .map(|m| { - ( - m.pattern_index, - format_captures(m.captures.iter().copied(), query, source), - ) - }) - .collect() + let mut result = Vec::new(); + while let Some(m) = matches.next() { + result.push(( + m.pattern_index, + format_captures(m.captures.iter().into_streaming_iter_ref(), query, source), + )); + } + result } pub fn collect_captures<'a>( - captures: impl Iterator, usize)>, + captures: impl StreamingIterator, usize)>, query: &'a Query, source: &'a str, ) -> Vec<(&'a str, &'a str)> { - format_captures(captures.map(|(m, i)| m.captures[i]), query, source) + format_captures(captures.map(|(m, i)| m.captures[*i]), query, source) } fn format_captures<'a>( - captures: impl Iterator>, + mut captures: impl StreamingIterator>, query: &'a Query, source: &'a str, ) -> Vec<(&'a str, &'a str)> { - captures - .map(|capture| { - ( - query.capture_names()[capture.index as usize], - capture.node.utf8_text(source.as_bytes()).unwrap(), - ) - }) - .collect() + let mut result = Vec::new(); + while let Some(capture) = captures.next() { + result.push(( + query.capture_names()[capture.index as usize], + capture.node.utf8_text(source.as_bytes()).unwrap(), + )); + } + result } diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 45dc3144..b0aa6b2b 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -3,6 +3,7 @@ use std::{env, fmt::Write}; use indoc::indoc; use lazy_static::lazy_static; use rand::{prelude::StdRng, SeedableRng}; +use streaming_iterator::StreamingIterator; use tree_sitter::{ CaptureQuantifier, Language, Node, Parser, Point, Query, QueryCursor, QueryError, QueryErrorKind, QueryPredicate, QueryPredicateArg, QueryProperty, @@ -2267,29 +2268,50 @@ fn test_query_matches_with_wildcard_at_root_intersecting_byte_range() { // After the first line of the class definition let offset = source.find("A:").unwrap() + 2; - let matches = cursor - .set_byte_range(offset..offset) - .matches(&query, tree.root_node(), source.as_bytes()) - .map(|mat| mat.captures[0].node.kind()) - .collect::>(); + let mut matches = Vec::new(); + let mut match_iter = cursor.set_byte_range(offset..offset).matches( + &query, + tree.root_node(), + source.as_bytes(), + ); + + while let Some(mat) = match_iter.next() { + if let Some(capture) = mat.captures.first() { + matches.push(capture.node.kind()); + } + } assert_eq!(matches, &["class_definition"]); // After the first line of the function definition let offset = source.find("b():").unwrap() + 4; - let matches = cursor - .set_byte_range(offset..offset) - .matches(&query, tree.root_node(), source.as_bytes()) - .map(|mat| mat.captures[0].node.kind()) - .collect::>(); + let mut matches = Vec::new(); + let mut match_iter = cursor.set_byte_range(offset..offset).matches( + &query, + tree.root_node(), + source.as_bytes(), + ); + + while let Some(mat) = match_iter.next() { + if let Some(capture) = mat.captures.first() { + matches.push(capture.node.kind()); + } + } assert_eq!(matches, &["class_definition", "function_definition"]); // After the first line of the if statement let offset = source.find("c:").unwrap() + 2; - let matches = cursor - .set_byte_range(offset..offset) - .matches(&query, tree.root_node(), source.as_bytes()) - .map(|mat| mat.captures[0].node.kind()) - .collect::>(); + let mut matches = Vec::new(); + let mut match_iter = cursor.set_byte_range(offset..offset).matches( + &query, + tree.root_node(), + source.as_bytes(), + ); + + while let Some(mat) = match_iter.next() { + if let Some(capture) = mat.captures.first() { + matches.push(capture.node.kind()); + } + } assert_eq!( matches, &["class_definition", "function_definition", "if_statement"] @@ -2342,8 +2364,9 @@ fn test_query_captures_within_byte_range_assigned_after_iterating() { // Retrieve some captures let mut results = Vec::new(); - for (mat, capture_ix) in captures.by_ref().take(5) { - let capture = mat.captures[capture_ix]; + let mut first_five = captures.by_ref().take(5); + while let Some((mat, capture_ix)) = first_five.next() { + let capture = mat.captures[*capture_ix]; results.push(( query.capture_names()[capture.index as usize], &source[capture.node.byte_range()], @@ -2365,8 +2388,8 @@ fn test_query_captures_within_byte_range_assigned_after_iterating() { // intersect the range. results.clear(); captures.set_byte_range(source.find("Ok").unwrap()..source.len()); - for (mat, capture_ix) in captures { - let capture = mat.captures[capture_ix]; + while let Some((mat, capture_ix)) = captures.next() { + let capture = mat.captures[*capture_ix]; results.push(( query.capture_names()[capture.index as usize], &source[capture.node.byte_range()], @@ -2602,21 +2625,23 @@ fn test_query_matches_with_captured_wildcard_at_root() { parser.set_language(&language).unwrap(); let tree = parser.parse(source, None).unwrap(); - let match_capture_names_and_rows = cursor - .matches(&query, tree.root_node(), source.as_bytes()) - .map(|m| { - m.captures - .iter() - .map(|c| { - ( - query.capture_names()[c.index as usize], - c.node.kind(), - c.node.start_position().row, - ) - }) - .collect::>() - }) - .collect::>(); + let mut match_capture_names_and_rows = Vec::new(); + let mut match_iter = cursor.matches(&query, tree.root_node(), source.as_bytes()); + + while let Some(m) = match_iter.next() { + let captures = m + .captures + .iter() + .map(|c| { + ( + query.capture_names()[c.index as usize], + c.node.kind(), + c.node.start_position().row, + ) + }) + .collect::>(); + match_capture_names_and_rows.push(captures); + } assert_eq!( match_capture_names_and_rows, @@ -3460,9 +3485,13 @@ fn test_query_captures_with_matches_removed() { let mut cursor = QueryCursor::new(); let mut captured_strings = Vec::new(); - for (m, i) in cursor.captures(&query, tree.root_node(), source.as_bytes()) { - let capture = m.captures[i]; + + let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); + while let Some((m, i)) = captures.next() { + println!("captured: {:?}, {}", m, i); + let capture = m.captures[*i]; let text = capture.node.utf8_text(source.as_bytes()).unwrap(); + println!("captured: {:?}", text); if text == "a" { m.remove(); continue; @@ -3504,8 +3533,9 @@ fn test_query_captures_with_matches_removed_before_they_finish() { let mut cursor = QueryCursor::new(); let mut captured_strings = Vec::new(); - for (m, i) in cursor.captures(&query, tree.root_node(), source.as_bytes()) { - let capture = m.captures[i]; + let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); + while let Some((m, i)) = captures.next() { + let capture = m.captures[*i]; let text = capture.node.utf8_text(source.as_bytes()).unwrap(); if text == "as" { m.remove(); @@ -3912,21 +3942,24 @@ fn test_query_random() { panic!("failed to build query for pattern {pattern} - {e}. seed: {seed}"); } }; - let mut actual_matches = cursor - .matches( - &query, - test_tree.root_node(), - include_bytes!("parser_test.rs").as_ref(), - ) - .map(|mat| Match { + let mut actual_matches = Vec::new(); + let mut match_iter = cursor.matches( + &query, + test_tree.root_node(), + include_bytes!("parser_test.rs").as_ref(), + ); + + while let Some(mat) = match_iter.next() { + let transformed_match = Match { last_node: None, captures: mat .captures .iter() .map(|c| (query.capture_names()[c.index as usize], c.node)) .collect::>(), - }) - .collect::>(); + }; + actual_matches.push(transformed_match); + } // actual_matches.sort_unstable(); actual_matches.dedup(); @@ -4908,12 +4941,12 @@ fn test_consecutive_zero_or_modifiers() { assert!(matches.next().is_some()); let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, three_tree.root_node(), three_source.as_bytes()); + let mut matches = cursor.matches(&query, three_tree.root_node(), three_source.as_bytes()); let mut len_3 = false; let mut len_1 = false; - for m in matches { + while let Some(m) = matches.next() { if m.captures.len() == 3 { len_3 = true; } diff --git a/cli/src/tests/text_provider_test.rs b/cli/src/tests/text_provider_test.rs index e35e20ec..7c1d538c 100644 --- a/cli/src/tests/text_provider_test.rs +++ b/cli/src/tests/text_provider_test.rs @@ -1,5 +1,6 @@ use std::{iter, sync::Arc}; +use streaming_iterator::StreamingIterator; use tree_sitter::{Language, Node, Parser, Point, Query, QueryCursor, TextProvider, Tree}; use crate::tests::helpers::fixtures::get_language; @@ -30,8 +31,8 @@ fn tree_query>(tree: &Tree, text: impl TextProvider, language: let mut cursor = QueryCursor::new(); let mut captures = cursor.captures(&query, tree.root_node(), text); let (match_, idx) = captures.next().unwrap(); - let capture = match_.captures[idx]; - assert_eq!(capture.index as usize, idx); + let capture = match_.captures[*idx]; + assert_eq!(capture.index as usize, *idx); assert_eq!("comment", capture.node.kind()); } diff --git a/cli/src/tests/wasm_language_test.rs b/cli/src/tests/wasm_language_test.rs index 1ea63658..34584dae 100644 --- a/cli/src/tests/wasm_language_test.rs +++ b/cli/src/tests/wasm_language_test.rs @@ -1,6 +1,7 @@ use std::fs; use lazy_static::lazy_static; +use streaming_iterator::StreamingIterator; use tree_sitter::{ wasmtime::Engine, Parser, Query, QueryCursor, WasmError, WasmErrorKind, WasmStore, }; diff --git a/highlight/Cargo.toml b/highlight/Cargo.toml index 694f5064..1b57fd52 100644 --- a/highlight/Cargo.toml +++ b/highlight/Cargo.toml @@ -22,5 +22,6 @@ crate-type = ["lib", "staticlib"] lazy_static.workspace = true regex.workspace = true thiserror.workspace = true +streaming-iterator.workspace = true tree-sitter.workspace = true diff --git a/highlight/src/lib.rs b/highlight/src/lib.rs index 74684669..b4e38d58 100644 --- a/highlight/src/lib.rs +++ b/highlight/src/lib.rs @@ -1,18 +1,23 @@ #![doc = include_str!("../README.md")] pub mod c_lib; +use core::slice; use std::{ collections::HashSet, - iter, mem, ops, str, + iter, + marker::PhantomData, + mem::{self, MaybeUninit}, + ops, str, sync::atomic::{AtomicUsize, Ordering}, }; pub use c_lib as c; use lazy_static::lazy_static; +use streaming_iterator::StreamingIterator; use thiserror::Error; use tree_sitter::{ - Language, LossyUtf8, Node, Parser, Point, Query, QueryCaptures, QueryCursor, QueryError, - QueryMatch, Range, Tree, + ffi, Language, LossyUtf8, Node, Parser, Point, Query, QueryCapture, QueryCaptures, QueryCursor, + QueryError, QueryMatch, Range, TextProvider, Tree, }; const CANCELLATION_CHECK_INTERVAL: usize = 100; @@ -171,7 +176,7 @@ where struct HighlightIterLayer<'a> { _tree: Tree, cursor: QueryCursor, - captures: iter::Peekable>, + captures: iter::Peekable<_QueryCaptures<'a, 'a, &'a [u8], &'a [u8]>>, config: &'a HighlightConfiguration, highlight_end_stack: Vec, scope_stack: Vec>, @@ -179,6 +184,77 @@ struct HighlightIterLayer<'a> { depth: usize, } +pub struct _QueryCaptures<'query, 'tree: 'query, T: TextProvider, I: AsRef<[u8]>> { + ptr: *mut ffi::TSQueryCursor, + query: &'query Query, + text_provider: T, + buffer1: Vec, + buffer2: Vec, + _current_match: Option<(QueryMatch<'query, 'tree>, usize)>, + _phantom: PhantomData<(&'tree (), I)>, +} + +struct _QueryMatch<'cursor, 'tree> { + pub _pattern_index: usize, + pub _captures: &'cursor [QueryCapture<'tree>], + _id: u32, + _cursor: *mut ffi::TSQueryCursor, +} + +impl<'tree> _QueryMatch<'_, 'tree> { + fn new(m: &ffi::TSQueryMatch, cursor: *mut ffi::TSQueryCursor) -> Self { + _QueryMatch { + _cursor: cursor, + _id: m.id, + _pattern_index: m.pattern_index as usize, + _captures: (m.capture_count > 0) + .then(|| unsafe { + slice::from_raw_parts( + m.captures.cast::>(), + m.capture_count as usize, + ) + }) + .unwrap_or_default(), + } + } +} + +impl<'query, 'tree: 'query, T: TextProvider, I: AsRef<[u8]>> Iterator + for _QueryCaptures<'query, 'tree, T, I> +{ + type Item = (QueryMatch<'query, 'tree>, usize); + + fn next(&mut self) -> Option { + unsafe { + loop { + let mut capture_index = 0u32; + let mut m = MaybeUninit::::uninit(); + if ffi::ts_query_cursor_next_capture( + self.ptr, + m.as_mut_ptr(), + core::ptr::addr_of_mut!(capture_index), + ) { + let result = std::mem::transmute::<_QueryMatch, QueryMatch>(_QueryMatch::new( + &m.assume_init(), + self.ptr, + )); + if result.satisfies_text_predicates( + self.query, + &mut self.buffer1, + &mut self.buffer2, + &mut self.text_provider, + ) { + return Some((result, capture_index as usize)); + } + result.remove(); + } else { + return None; + } + } + } + } +} + impl Default for Highlighter { fn default() -> Self { Self::new() @@ -456,15 +532,15 @@ impl<'a> HighlightIterLayer<'a> { if let Some(combined_injections_query) = &config.combined_injections_query { let mut injections_by_pattern_index = vec![(None, Vec::new(), false); combined_injections_query.pattern_count()]; - let matches = + let mut matches = cursor.matches(combined_injections_query, tree.root_node(), source); - for mat in matches { + while let Some(mat) = matches.next() { let entry = &mut injections_by_pattern_index[mat.pattern_index]; let (language_name, content_node, include_children) = injection_for_match( config, parent_name, combined_injections_query, - &mat, + mat, source, ); if language_name.is_some() { @@ -499,9 +575,12 @@ impl<'a> HighlightIterLayer<'a> { let cursor_ref = unsafe { mem::transmute::<&mut QueryCursor, &'static mut QueryCursor>(&mut cursor) }; - let captures = cursor_ref - .captures(&config.query, tree_ref.root_node(), source) - .peekable(); + let captures = unsafe { + std::mem::transmute::, _QueryCaptures<_, _>>( + cursor_ref.captures(&config.query, tree_ref.root_node(), source), + ) + } + .peekable(); result.push(HighlightIterLayer { highlight_end_stack: Vec::new(), diff --git a/lib/Cargo.toml b/lib/Cargo.toml index 10ec7540..c261e36c 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -34,6 +34,7 @@ wasm = ["wasmtime-c-api"] regex = { version = "1.10.6", default-features = false, features = ["unicode"] } regex-syntax = { version = "0.8.4", default-features = false } tree-sitter-language = { version = "0.1", path = "language" } +streaming-iterator = "0.1.9" [dependencies.wasmtime-c-api] version = "25.0.1" diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 43101596..32543a03 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -27,6 +27,7 @@ use std::os::fd::AsRawFd; #[cfg(all(windows, feature = "std"))] use std::os::windows::io::AsRawHandle; +use streaming_iterator::{StreamingIterator, StreamingIteratorMut}; use tree_sitter_language::LanguageFn; #[cfg(feature = "wasm")] @@ -201,23 +202,25 @@ pub struct QueryMatch<'cursor, 'tree> { } /// A sequence of [`QueryMatch`]es associated with a given [`QueryCursor`]. -pub struct QueryMatches<'query, 'cursor, T: TextProvider, I: AsRef<[u8]>> { +pub struct QueryMatches<'query, 'tree: 'query, T: TextProvider, I: AsRef<[u8]>> { ptr: *mut ffi::TSQueryCursor, query: &'query Query, text_provider: T, buffer1: Vec, buffer2: Vec, - _phantom: PhantomData<(&'cursor (), I)>, + current_match: Option>, + _phantom: PhantomData<(&'tree (), I)>, } /// A sequence of [`QueryCapture`]s associated with a given [`QueryCursor`]. -pub struct QueryCaptures<'query, 'cursor, T: TextProvider, I: AsRef<[u8]>> { +pub struct QueryCaptures<'query, 'tree: 'query, T: TextProvider, I: AsRef<[u8]>> { ptr: *mut ffi::TSQueryCursor, query: &'query Query, text_provider: T, buffer1: Vec, buffer2: Vec, - _phantom: PhantomData<(&'cursor (), I)>, + current_match: Option<(QueryMatch<'query, 'tree>, usize)>, + _phantom: PhantomData<(&'tree (), I)>, } pub trait TextProvider @@ -2433,6 +2436,7 @@ impl QueryCursor { text_provider, buffer1: Vec::default(), buffer2: Vec::default(), + current_match: None, _phantom: PhantomData, } } @@ -2457,6 +2461,7 @@ impl QueryCursor { text_provider, buffer1: Vec::default(), buffer2: Vec::default(), + current_match: None, _phantom: PhantomData, } } @@ -2522,7 +2527,7 @@ impl<'tree> QueryMatch<'_, 'tree> { } #[doc(alias = "ts_query_cursor_remove_match")] - pub fn remove(self) { + pub fn remove(&self) { unsafe { ffi::ts_query_cursor_remove_match(self.cursor, self.id) } } @@ -2551,7 +2556,7 @@ impl<'tree> QueryMatch<'_, 'tree> { } } - fn satisfies_text_predicates>( + pub fn satisfies_text_predicates>( &self, query: &Query, buffer1: &mut Vec, @@ -2669,13 +2674,16 @@ impl QueryProperty { } } -impl<'query, 'tree: 'query, T: TextProvider, I: AsRef<[u8]>> Iterator +/// Provide StreamingIterator instead of traditional one as the underlying object in the C library +/// gets updated on each iteration. Created copies would have their internal state overwritten, +/// leading to Undefined Behavior +impl<'query, 'tree: 'query, T: TextProvider, I: AsRef<[u8]>> StreamingIterator for QueryMatches<'query, 'tree, T, I> { type Item = QueryMatch<'query, 'tree>; - fn next(&mut self) -> Option { - unsafe { + fn advance(&mut self) { + self.current_match = unsafe { loop { let mut m = MaybeUninit::::uninit(); if ffi::ts_query_cursor_next_match(self.ptr, m.as_mut_ptr()) { @@ -2686,23 +2694,35 @@ impl<'query, 'tree: 'query, T: TextProvider, I: AsRef<[u8]>> Iterator &mut self.buffer2, &mut self.text_provider, ) { - return Some(result); + break Some(result); } } else { - return None; + break None; } } - } + }; + } + + fn get(&self) -> Option<&Self::Item> { + self.current_match.as_ref() } } -impl<'query, 'tree: 'query, T: TextProvider, I: AsRef<[u8]>> Iterator +impl<'query, 'tree: 'query, T: TextProvider, I: AsRef<[u8]>> StreamingIteratorMut + for QueryMatches<'query, 'tree, T, I> +{ + fn get_mut(&mut self) -> Option<&mut Self::Item> { + self.current_match.as_mut() + } +} + +impl<'query, 'tree: 'query, T: TextProvider, I: AsRef<[u8]>> StreamingIterator for QueryCaptures<'query, 'tree, T, I> { type Item = (QueryMatch<'query, 'tree>, usize); - fn next(&mut self) -> Option { - unsafe { + fn advance(&mut self) { + self.current_match = unsafe { loop { let mut capture_index = 0u32; let mut m = MaybeUninit::::uninit(); @@ -2718,15 +2738,27 @@ impl<'query, 'tree: 'query, T: TextProvider, I: AsRef<[u8]>> Iterator &mut self.buffer2, &mut self.text_provider, ) { - return Some((result, capture_index as usize)); + break Some((result, capture_index as usize)); } result.remove(); } else { - return None; + break None; } } } } + + fn get(&self) -> Option<&Self::Item> { + self.current_match.as_ref() + } +} + +impl<'query, 'tree: 'query, T: TextProvider, I: AsRef<[u8]>> StreamingIteratorMut + for QueryCaptures<'query, 'tree, T, I> +{ + fn get_mut(&mut self) -> Option<&mut Self::Item> { + self.current_match.as_mut() + } } impl, I: AsRef<[u8]>> QueryMatches<'_, '_, T, I> { diff --git a/tags/Cargo.toml b/tags/Cargo.toml index 65cf9251..b7d0846c 100644 --- a/tags/Cargo.toml +++ b/tags/Cargo.toml @@ -21,6 +21,7 @@ crate-type = ["lib", "staticlib"] [dependencies] memchr.workspace = true regex.workspace = true +streaming-iterator.workspace = true thiserror.workspace = true tree-sitter.workspace = true diff --git a/tags/src/lib.rs b/tags/src/lib.rs index 0bc27d20..c7ba2df3 100644 --- a/tags/src/lib.rs +++ b/tags/src/lib.rs @@ -15,6 +15,7 @@ use std::{ use memchr::memchr; use regex::Regex; +use streaming_iterator::StreamingIterator; use thiserror::Error; use tree_sitter::{ Language, LossyUtf8, Parser, Point, Query, QueryCursor, QueryError, QueryPredicateArg, Tree, @@ -100,7 +101,7 @@ struct LocalScope<'a> { struct TagsIter<'a, I> where - I: Iterator>, + I: StreamingIterator>, { matches: I, _tree: Tree, @@ -316,7 +317,7 @@ impl TagsContext { impl<'a, I> Iterator for TagsIter<'a, I> where - I: Iterator>, + I: StreamingIterator>, { type Item = Result;