From 937dcf5fd139badd37cad5de4dfe1123040c36b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20M=C3=BCller?= Date: Mon, 16 Jun 2025 10:31:25 -0700 Subject: [PATCH] feat(rust)!: use ops::ControlFlow as parse and query progress return value Instead of returning an undocumented boolean flag, use a core::ops::ControlFlow object. At the expense of being a bit more verbose, this is a type that should be self-explanatory in the context of a callback, as an indication of whether to continue processing or stop. --- crates/cli/src/parse.rs | 7 ++- crates/cli/src/tests/parser_test.rs | 85 +++++++++++++++++++++-------- crates/cli/src/tests/query_test.rs | 11 +++- crates/highlight/src/highlight.rs | 11 +++- crates/tags/src/tags.rs | 10 +++- lib/binding_rust/lib.rs | 43 +++++++++++---- 6 files changed, 120 insertions(+), 47 deletions(-) diff --git a/crates/cli/src/parse.rs b/crates/cli/src/parse.rs index d5285011..ed0ac9f6 100644 --- a/crates/cli/src/parse.rs +++ b/crates/cli/src/parse.rs @@ -1,6 +1,7 @@ use std::{ fmt, fs, io::{self, Write}, + ops::ControlFlow, path::{Path, PathBuf}, sync::atomic::{AtomicUsize, Ordering}, time::{Duration, Instant}, @@ -357,15 +358,15 @@ pub fn parse_file_at_path( let progress_callback = &mut |_: &ParseState| { if let Some(cancellation_flag) = opts.cancellation_flag { if cancellation_flag.load(Ordering::SeqCst) != 0 { - return true; + return ControlFlow::Break(()); } } if opts.timeout > 0 && start_time.elapsed().as_micros() > opts.timeout as u128 { - return true; + return ControlFlow::Break(()); } - false + ControlFlow::Continue(()) }; let parse_opts = ParseOptions::new().progress_callback(progress_callback); diff --git a/crates/cli/src/tests/parser_test.rs b/crates/cli/src/tests/parser_test.rs index d8b9767d..ba897163 100644 --- a/crates/cli/src/tests/parser_test.rs +++ b/crates/cli/src/tests/parser_test.rs @@ -1,4 +1,5 @@ use std::{ + ops::ControlFlow, sync::atomic::{AtomicUsize, Ordering}, thread, time, }; @@ -699,7 +700,13 @@ fn test_parsing_on_multiple_threads() { fn test_parsing_cancelled_by_another_thread() { let cancellation_flag = std::sync::Arc::new(AtomicUsize::new(0)); let flag = cancellation_flag.clone(); - let callback = &mut |_: &ParseState| cancellation_flag.load(Ordering::SeqCst) != 0; + let callback = &mut |_: &ParseState| { + if cancellation_flag.load(Ordering::SeqCst) != 0 { + ControlFlow::Break(()) + } else { + ControlFlow::Continue(()) + } + }; let mut parser = Parser::new(); parser.set_language(&get_language("javascript")).unwrap(); @@ -764,9 +771,13 @@ fn test_parsing_with_a_timeout() { } }, None, - Some( - ParseOptions::new().progress_callback(&mut |_| start_time.elapsed().as_micros() > 1000), - ), + Some(ParseOptions::new().progress_callback(&mut |_| { + if start_time.elapsed().as_micros() > 1000 { + ControlFlow::Break(()) + } else { + ControlFlow::Continue(()) + } + })), ); assert!(tree.is_none()); assert!(start_time.elapsed().as_micros() < 2000); @@ -782,9 +793,13 @@ fn test_parsing_with_a_timeout() { } }, None, - Some( - ParseOptions::new().progress_callback(&mut |_| start_time.elapsed().as_micros() > 5000), - ), + Some(ParseOptions::new().progress_callback(&mut |_| { + if start_time.elapsed().as_micros() > 5000 { + ControlFlow::Break(()) + } else { + ControlFlow::Continue(()) + } + })), ); assert!(tree.is_none()); assert!(start_time.elapsed().as_micros() > 100); @@ -822,7 +837,13 @@ fn test_parsing_with_a_timeout_and_a_reset() { } }, None, - Some(ParseOptions::new().progress_callback(&mut |_| start_time.elapsed().as_micros() > 5)), + Some(ParseOptions::new().progress_callback(&mut |_| { + if start_time.elapsed().as_micros() > 5 { + ControlFlow::Break(()) + } else { + ControlFlow::Continue(()) + } + })), ); assert!(tree.is_none()); @@ -853,7 +874,13 @@ fn test_parsing_with_a_timeout_and_a_reset() { } }, None, - Some(ParseOptions::new().progress_callback(&mut |_| start_time.elapsed().as_micros() > 5)), + Some(ParseOptions::new().progress_callback(&mut |_| { + if start_time.elapsed().as_micros() > 5 { + ControlFlow::Break(()) + } else { + ControlFlow::Continue(()) + } + })), ); assert!(tree.is_none()); @@ -893,10 +920,13 @@ fn test_parsing_with_a_timeout_and_implicit_reset() { } }, None, - Some( - ParseOptions::new() - .progress_callback(&mut |_| start_time.elapsed().as_micros() > 5), - ), + Some(ParseOptions::new().progress_callback(&mut |_| { + if start_time.elapsed().as_micros() > 5 { + ControlFlow::Break(()) + } else { + ControlFlow::Continue(()) + } + })), ); assert!(tree.is_none()); @@ -937,10 +967,13 @@ fn test_parsing_with_timeout_and_no_completion() { } }, None, - Some( - ParseOptions::new() - .progress_callback(&mut |_| start_time.elapsed().as_micros() > 5), - ), + Some(ParseOptions::new().progress_callback(&mut |_| { + if start_time.elapsed().as_micros() > 5 { + ControlFlow::Break(()) + } else { + ControlFlow::Continue(()) + } + })), ); assert!(tree.is_none()); @@ -979,10 +1012,10 @@ fn test_parsing_with_timeout_during_balancing() { // are in the balancing phase. if state.current_byte_offset() != current_byte_offset { current_byte_offset = state.current_byte_offset(); - false + ControlFlow::Continue(()) } else { in_balancing = true; - true + ControlFlow::Break(()) } })), ); @@ -1004,10 +1037,10 @@ fn test_parsing_with_timeout_during_balancing() { Some(ParseOptions::new().progress_callback(&mut |state| { if state.current_byte_offset() != current_byte_offset { current_byte_offset = state.current_byte_offset(); - false + ControlFlow::Continue(()) } else { in_balancing = true; - true + ControlFlow::Break(()) } })), ); @@ -1031,7 +1064,7 @@ fn test_parsing_with_timeout_during_balancing() { // Because we've already finished parsing, we should only be resuming the // balancing phase. assert!(state.current_byte_offset() == current_byte_offset); - false + ControlFlow::Continue(()) })), ) .unwrap(); @@ -1057,7 +1090,11 @@ fn test_parsing_with_timeout_when_error_detected() { None, Some(ParseOptions::new().progress_callback(&mut |state| { offset = state.current_byte_offset(); - state.has_error() + if state.has_error() { + ControlFlow::Break(()) + } else { + ControlFlow::Continue(()) + } })), ); @@ -1737,7 +1774,7 @@ fn test_parsing_by_halting_at_offset() { None, Some(ParseOptions::new().progress_callback(&mut |p| { seen_byte_offsets.push(p.current_byte_offset()); - false + ControlFlow::Continue(()) })), ) .unwrap(); diff --git a/crates/cli/src/tests/query_test.rs b/crates/cli/src/tests/query_test.rs index ab025e04..835138ce 100644 --- a/crates/cli/src/tests/query_test.rs +++ b/crates/cli/src/tests/query_test.rs @@ -1,4 +1,4 @@ -use std::{env, fmt::Write, sync::LazyLock}; +use std::{env, fmt::Write, ops::ControlFlow, sync::LazyLock}; use indoc::indoc; use rand::{prelude::StdRng, SeedableRng}; @@ -5446,8 +5446,13 @@ fn test_query_execution_with_timeout() { &query, tree.root_node(), source_code.as_bytes(), - QueryCursorOptions::new() - .progress_callback(&mut |_| start_time.elapsed().as_micros() > 1000), + QueryCursorOptions::new().progress_callback(&mut |_| { + if start_time.elapsed().as_micros() > 1000 { + ControlFlow::Break(()) + } else { + ControlFlow::Continue(()) + } + }), ) .count(); assert!(matches < 1000); diff --git a/crates/highlight/src/highlight.rs b/crates/highlight/src/highlight.rs index e4555fa0..bb81fc08 100644 --- a/crates/highlight/src/highlight.rs +++ b/crates/highlight/src/highlight.rs @@ -7,7 +7,8 @@ use std::{ iter, marker::PhantomData, mem::{self, MaybeUninit}, - ops, str, + ops::{self, ControlFlow}, + str, sync::{ atomic::{AtomicUsize, Ordering}, LazyLock, @@ -538,9 +539,13 @@ impl<'a> HighlightIterLayer<'a> { None, Some(ParseOptions::new().progress_callback(&mut |_| { if let Some(cancellation_flag) = cancellation_flag { - cancellation_flag.load(Ordering::SeqCst) != 0 + if cancellation_flag.load(Ordering::SeqCst) != 0 { + ControlFlow::Break(()) + } else { + ControlFlow::Continue(()) + } } else { - false + ControlFlow::Continue(()) } })), ) diff --git a/crates/tags/src/tags.rs b/crates/tags/src/tags.rs index 57179e9a..16270b0a 100644 --- a/crates/tags/src/tags.rs +++ b/crates/tags/src/tags.rs @@ -7,7 +7,7 @@ use std::{ collections::HashMap, ffi::{CStr, CString}, mem, - ops::Range, + ops::{ControlFlow, Range}, os::raw::c_char, str, sync::atomic::{AtomicUsize, Ordering}, @@ -301,9 +301,13 @@ impl TagsContext { None, Some(ParseOptions::new().progress_callback(&mut |_| { if let Some(cancellation_flag) = cancellation_flag { - cancellation_flag.load(Ordering::SeqCst) != 0 + if cancellation_flag.load(Ordering::SeqCst) != 0 { + ControlFlow::Break(()) + } else { + ControlFlow::Continue(()) + } } else { - false + ControlFlow::Continue(()) } })), ) diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index cb4b34c0..e3b2d950 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -16,7 +16,7 @@ use core::{ marker::PhantomData, mem::MaybeUninit, num::NonZeroU16, - ops::{self, Deref}, + ops::{self, ControlFlow, Deref}, ptr::{self, NonNull}, slice, str, sync::atomic::AtomicUsize, @@ -177,7 +177,10 @@ impl<'a> ParseOptions<'a> { } #[must_use] - pub fn progress_callback bool>(mut self, callback: &'a mut F) -> Self { + pub fn progress_callback ControlFlow<()>>( + mut self, + callback: &'a mut F, + ) -> Self { self.progress_callback = Some(callback); self } @@ -195,7 +198,7 @@ impl<'a> QueryCursorOptions<'a> { } #[must_use] - pub fn progress_callback bool>( + pub fn progress_callback ControlFlow<()>>( mut self, callback: &'a mut F, ) -> Self { @@ -232,10 +235,10 @@ type FieldId = NonZeroU16; type Logger<'a> = Box; /// A callback that receives the parse state during parsing. -type ParseProgressCallback<'a> = &'a mut dyn FnMut(&ParseState) -> bool; +type ParseProgressCallback<'a> = &'a mut dyn FnMut(&ParseState) -> ControlFlow<()>; /// A callback that receives the query state during query execution. -type QueryProgressCallback<'a> = &'a mut dyn FnMut(&QueryCursorState) -> bool; +type QueryProgressCallback<'a> = &'a mut dyn FnMut(&QueryCursorState) -> ControlFlow<()>; pub trait Decode { /// A callback that decodes the next code point from the input slice. It should return the code @@ -869,7 +872,10 @@ impl Parser { .cast::() .as_mut() .unwrap(); - callback(&ParseState::from_raw(state)) + match callback(&ParseState::from_raw(state)) { + ControlFlow::Continue(()) => false, + ControlFlow::Break(()) => true, + } } // This C function is passed to Tree-sitter as the input callback. @@ -1001,7 +1007,10 @@ impl Parser { .cast::() .as_mut() .unwrap(); - callback(&ParseState::from_raw(state)) + match callback(&ParseState::from_raw(state)) { + ControlFlow::Continue(()) => false, + ControlFlow::Break(()) => true, + } } // This C function is passed to Tree-sitter as the input callback. @@ -1118,7 +1127,10 @@ impl Parser { .cast::() .as_mut() .unwrap(); - callback(&ParseState::from_raw(state)) + match callback(&ParseState::from_raw(state)) { + ControlFlow::Continue(()) => false, + ControlFlow::Break(()) => true, + } } // This C function is passed to Tree-sitter as the input callback. @@ -1218,7 +1230,10 @@ impl Parser { .cast::() .as_mut() .unwrap(); - callback(&ParseState::from_raw(state)) + match callback(&ParseState::from_raw(state)) { + ControlFlow::Continue(()) => false, + ControlFlow::Break(()) => true, + } } // At compile time, create a C-compatible callback that calls the custom `decode` method. @@ -3103,7 +3118,10 @@ impl QueryCursor { .cast::() .as_mut() .unwrap(); - (callback)(&QueryCursorState::from_raw(state)) + match callback(&QueryCursorState::from_raw(state)) { + ControlFlow::Continue(()) => false, + ControlFlow::Break(()) => true, + } } let query_options = options.progress_callback.map(|cb| { @@ -3189,7 +3207,10 @@ impl QueryCursor { .cast::() .as_mut() .unwrap(); - (callback)(&QueryCursorState::from_raw(state)) + match callback(&QueryCursorState::from_raw(state)) { + ControlFlow::Continue(()) => false, + ControlFlow::Break(()) => true, + } } let query_options = options.progress_callback.map(|cb| {