From 070f11b8bfbe693322dfc5f9a60d6dd68fdef28d Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 23 Sep 2019 16:55:28 -0700 Subject: [PATCH] Use ptr::NonNull in Rust bindings --- cli/src/tests/query_test.rs | 12 +-- lib/binding_rust/lib.rs | 147 ++++++++++++++++++------------------ 2 files changed, 81 insertions(+), 78 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 51eea4b2..11abc028 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -56,23 +56,23 @@ fn test_query_errors_on_invalid_symbols() { assert_eq!( Query::new(language, "(clas)"), - Err(QueryError::NodeType("clas")) + Err(QueryError::NodeType("clas".to_string())) ); assert_eq!( Query::new(language, "(if_statement (arrayyyyy))"), - Err(QueryError::NodeType("arrayyyyy")) + Err(QueryError::NodeType("arrayyyyy".to_string())) ); assert_eq!( Query::new(language, "(if_statement condition: (non_existent3))"), - Err(QueryError::NodeType("non_existent3")) + Err(QueryError::NodeType("non_existent3".to_string())) ); assert_eq!( Query::new(language, "(if_statement condit: (identifier))"), - Err(QueryError::Field("condit")) + Err(QueryError::Field("condit".to_string())) ); assert_eq!( Query::new(language, "(if_statement conditioning: (identifier))"), - Err(QueryError::Field("conditioning")) + Err(QueryError::Field("conditioning".to_string())) ); }); } @@ -96,7 +96,7 @@ fn test_query_errors_on_invalid_conditions() { ); assert_eq!( Query::new(language, "((identifier) @id (eq? @id @ok))"), - Err(QueryError::Capture("ok")) + Err(QueryError::Capture("ok".to_string())) ); }); } diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 2069e373..6a7fe88f 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -17,6 +17,7 @@ use std::ffi::CStr; use std::marker::PhantomData; use std::mem::MaybeUninit; use std::os::raw::{c_char, c_void}; +use std::ptr::NonNull; use std::sync::atomic::AtomicUsize; use std::{char, fmt, ptr, slice, str, u16}; @@ -123,9 +124,9 @@ pub struct PropertySheetJSON

{ #[derive(Clone, Copy)] pub struct Node<'a>(ffi::TSNode, PhantomData<&'a ()>); -pub struct Parser(*mut ffi::TSParser); +pub struct Parser(NonNull); -pub struct Tree(*mut ffi::TSTree); +pub struct Tree(NonNull); pub struct TreeCursor<'a>(ffi::TSTreeCursor, PhantomData<&'a ()>); @@ -146,29 +147,30 @@ enum QueryPredicate { #[derive(Debug)] pub struct Query { - ptr: *mut ffi::TSQuery, + ptr: NonNull, capture_names: Vec, predicates: Vec>, } -pub struct QueryCursor(*mut ffi::TSQueryCursor); +pub struct QueryCursor(NonNull); pub struct QueryMatch<'a> { pub pattern_index: usize, captures: &'a [ffi::TSQueryCapture], } +#[derive(Clone)] pub struct QueryCapture<'a> { pub index: usize, pub node: Node<'a>, } #[derive(Debug, PartialEq, Eq)] -pub enum QueryError<'a> { +pub enum QueryError { Syntax(usize), - NodeType(&'a str), - Field(&'a str), - Capture(&'a str), + NodeType(String), + Field(String), + Capture(String), Predicate(String), } @@ -230,15 +232,11 @@ impl fmt::Display for LanguageError { } } -unsafe impl Send for Language {} - -unsafe impl Sync for Language {} - impl Parser { pub fn new() -> Parser { unsafe { let parser = ffi::ts_parser_new(); - Parser(parser) + Parser(NonNull::new_unchecked(parser)) } } @@ -250,14 +248,14 @@ impl Parser { Err(LanguageError { version }) } else { unsafe { - ffi::ts_parser_set_language(self.0, language.0); + ffi::ts_parser_set_language(self.0.as_ptr(), language.0); } Ok(()) } } pub fn language(&self) -> Option { - let ptr = unsafe { ffi::ts_parser_language(self.0) }; + let ptr = unsafe { ffi::ts_parser_language(self.0.as_ptr()) }; if ptr.is_null() { None } else { @@ -266,12 +264,12 @@ impl Parser { } pub fn logger(&self) -> Option<&Logger> { - let logger = unsafe { ffi::ts_parser_logger(self.0) }; + let logger = unsafe { ffi::ts_parser_logger(self.0.as_ptr()) }; unsafe { (logger.payload as *mut Logger).as_ref() } } pub fn set_logger(&mut self, logger: Option) { - let prev_logger = unsafe { ffi::ts_parser_logger(self.0) }; + let prev_logger = unsafe { ffi::ts_parser_logger(self.0.as_ptr()) }; if !prev_logger.payload.is_null() { drop(unsafe { Box::from_raw(prev_logger.payload as *mut Logger) }); } @@ -309,17 +307,17 @@ impl Parser { }; } - unsafe { ffi::ts_parser_set_logger(self.0, c_logger) }; + unsafe { ffi::ts_parser_set_logger(self.0.as_ptr(), c_logger) }; } #[cfg(unix)] pub fn print_dot_graphs(&mut self, file: &impl AsRawFd) { let fd = file.as_raw_fd(); - unsafe { ffi::ts_parser_print_dot_graphs(self.0, ffi::dup(fd)) } + unsafe { ffi::ts_parser_print_dot_graphs(self.0.as_ptr(), ffi::dup(fd)) } } pub fn stop_printing_dot_graphs(&mut self) { - unsafe { ffi::ts_parser_print_dot_graphs(self.0, -1) } + unsafe { ffi::ts_parser_print_dot_graphs(self.0.as_ptr(), -1) } } /// Parse a slice of UTF8 text. @@ -408,12 +406,10 @@ impl Parser { encoding: ffi::TSInputEncoding_TSInputEncodingUTF8, }; - let c_old_tree = old_tree.map_or(ptr::null_mut(), |t| t.0); - let c_new_tree = unsafe { ffi::ts_parser_parse(self.0, c_old_tree, c_input) }; - if c_new_tree.is_null() { - None - } else { - Some(Tree(c_new_tree)) + let c_old_tree = old_tree.map_or(ptr::null_mut(), |t| t.0.as_ptr()); + unsafe { + let c_new_tree = ffi::ts_parser_parse(self.0.as_ptr(), c_old_tree, c_input); + NonNull::new(c_new_tree).map(Tree) } } @@ -466,47 +462,49 @@ impl Parser { encoding: ffi::TSInputEncoding_TSInputEncodingUTF16, }; - let c_old_tree = old_tree.map_or(ptr::null_mut(), |t| t.0); - let c_new_tree = unsafe { ffi::ts_parser_parse(self.0, c_old_tree, c_input) }; - if c_new_tree.is_null() { - None - } else { - Some(Tree(c_new_tree)) + let c_old_tree = old_tree.map_or(ptr::null_mut(), |t| t.0.as_ptr()); + unsafe { + let c_new_tree = ffi::ts_parser_parse(self.0.as_ptr(), c_old_tree, c_input); + NonNull::new(c_new_tree).map(Tree) } } pub fn reset(&mut self) { - unsafe { ffi::ts_parser_reset(self.0) } + unsafe { ffi::ts_parser_reset(self.0.as_ptr()) } } pub fn timeout_micros(&self) -> u64 { - unsafe { ffi::ts_parser_timeout_micros(self.0) } + unsafe { ffi::ts_parser_timeout_micros(self.0.as_ptr()) } } pub fn set_timeout_micros(&mut self, timeout_micros: u64) { - unsafe { ffi::ts_parser_set_timeout_micros(self.0, timeout_micros) } + unsafe { ffi::ts_parser_set_timeout_micros(self.0.as_ptr(), timeout_micros) } } pub fn set_included_ranges(&mut self, ranges: &[Range]) { let ts_ranges: Vec = ranges.iter().cloned().map(|range| range.into()).collect(); unsafe { - ffi::ts_parser_set_included_ranges(self.0, ts_ranges.as_ptr(), ts_ranges.len() as u32) + ffi::ts_parser_set_included_ranges( + self.0.as_ptr(), + ts_ranges.as_ptr(), + ts_ranges.len() as u32, + ) }; } pub unsafe fn cancellation_flag(&self) -> Option<&AtomicUsize> { - (ffi::ts_parser_cancellation_flag(self.0) as *const AtomicUsize).as_ref() + (ffi::ts_parser_cancellation_flag(self.0.as_ptr()) as *const AtomicUsize).as_ref() } pub unsafe fn set_cancellation_flag(&self, flag: Option<&AtomicUsize>) { if let Some(flag) = flag { ffi::ts_parser_set_cancellation_flag( - self.0, + self.0.as_ptr(), flag as *const AtomicUsize as *const usize, ); } else { - ffi::ts_parser_set_cancellation_flag(self.0, ptr::null()); + ffi::ts_parser_set_cancellation_flag(self.0.as_ptr(), ptr::null()); } } } @@ -515,24 +513,22 @@ impl Drop for Parser { fn drop(&mut self) { self.stop_printing_dot_graphs(); self.set_logger(None); - unsafe { ffi::ts_parser_delete(self.0) } + unsafe { ffi::ts_parser_delete(self.0.as_ptr()) } } } -unsafe impl Send for Parser {} - impl Tree { pub fn root_node(&self) -> Node { - Node::new(unsafe { ffi::ts_tree_root_node(self.0) }).unwrap() + Node::new(unsafe { ffi::ts_tree_root_node(self.0.as_ptr()) }).unwrap() } pub fn language(&self) -> Language { - Language(unsafe { ffi::ts_tree_language(self.0) }) + Language(unsafe { ffi::ts_tree_language(self.0.as_ptr()) }) } pub fn edit(&mut self, edit: &InputEdit) { let edit = edit.into(); - unsafe { ffi::ts_tree_edit(self.0, &edit) }; + unsafe { ffi::ts_tree_edit(self.0.as_ptr(), &edit) }; } pub fn walk(&self) -> TreeCursor { @@ -550,15 +546,16 @@ impl Tree { pub fn changed_ranges(&self, other: &Tree) -> impl ExactSizeIterator { let mut count = 0; unsafe { - let ptr = - ffi::ts_tree_get_changed_ranges(self.0, other.0, &mut count as *mut _ as *mut u32); + let ptr = ffi::ts_tree_get_changed_ranges( + self.0.as_ptr(), + other.0.as_ptr(), + &mut count as *mut _ as *mut u32, + ); util::CBufferIter::new(ptr, count).map(|r| r.into()) } } } -unsafe impl Send for Tree {} - impl fmt::Debug for Tree { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { write!(f, "{{Tree {:?}}}", self.root_node()) @@ -567,13 +564,13 @@ impl fmt::Debug for Tree { impl Drop for Tree { fn drop(&mut self) { - unsafe { ffi::ts_tree_delete(self.0) } + unsafe { ffi::ts_tree_delete(self.0.as_ptr()) } } } impl Clone for Tree { fn clone(&self) -> Tree { - unsafe { Tree(ffi::ts_tree_copy(self.0)) } + unsafe { Tree(NonNull::new_unchecked(ffi::ts_tree_copy(self.0.as_ptr()))) } } } @@ -962,7 +959,7 @@ impl<'a, P> TreePropertyCursor<'a, P> { } impl Query { - pub fn new<'a>(language: Language, source: &'a str) -> Result> { + pub fn new(language: Language, source: &str) -> Result { let mut error_offset = 0u32; let mut error_type: ffi::TSQueryError = 0; let bytes = source.as_bytes(); @@ -986,7 +983,7 @@ impl Query { let end_offset = suffix .find(|c| !char::is_alphanumeric(c) && c != '_' && c != '-') .unwrap_or(source.len()); - let name = suffix.split_at(end_offset).0; + let name = suffix.split_at(end_offset).0.to_string(); match error_type { ffi::TSQueryError_TSQueryErrorNodeType => Err(QueryError::NodeType(name)), ffi::TSQueryError_TSQueryErrorField => Err(QueryError::Field(name)), @@ -1002,7 +999,7 @@ impl Query { let capture_count = unsafe { ffi::ts_query_capture_count(ptr) }; let pattern_count = unsafe { ffi::ts_query_pattern_count(ptr) as usize }; let mut result = Query { - ptr, + ptr: unsafe { NonNull::new_unchecked(ptr) }, capture_names: Vec::with_capacity(capture_count as usize), predicates: Vec::with_capacity(pattern_count), }; @@ -1137,11 +1134,13 @@ impl Query { self.predicates.len(), ); } - unsafe { ffi::ts_query_start_byte_for_pattern(self.ptr, pattern_index as u32) as usize } + unsafe { + ffi::ts_query_start_byte_for_pattern(self.ptr.as_ptr(), pattern_index as u32) as usize + } } pub fn pattern_count(&self) -> usize { - unsafe { ffi::ts_query_pattern_count(self.ptr) as usize } + unsafe { ffi::ts_query_pattern_count(self.ptr.as_ptr()) as usize } } pub fn capture_names(&self) -> &[String] { @@ -1151,7 +1150,7 @@ impl Query { impl QueryCursor { pub fn new() -> Self { - QueryCursor(unsafe { ffi::ts_query_cursor_new() }) + QueryCursor(unsafe { NonNull::new_unchecked(ffi::ts_query_cursor_new()) }) } pub fn matches<'a>( @@ -1160,17 +1159,16 @@ impl QueryCursor { node: Node<'a>, mut text_callback: impl FnMut(Node<'a>) -> &'a [u8] + 'a, ) -> impl Iterator> + 'a { - unsafe { - ffi::ts_query_cursor_exec(self.0, query.ptr, node.0); - } + let ptr = self.0.as_ptr(); + unsafe { ffi::ts_query_cursor_exec(ptr, query.ptr.as_ptr(), node.0) }; std::iter::from_fn(move || -> Option> { loop { unsafe { let mut m = MaybeUninit::::uninit(); - if ffi::ts_query_cursor_next_match(self.0, m.as_mut_ptr()) { + if ffi::ts_query_cursor_next_match(ptr, m.as_mut_ptr()) { let m = m.assume_init(); let captures = slice::from_raw_parts(m.captures, m.capture_count as usize); - if self.captures_match_condition( + if Self::captures_match_condition( query, captures, m.pattern_index as usize, @@ -1195,21 +1193,20 @@ impl QueryCursor { node: Node<'a>, mut text_callback: impl FnMut(Node<'a>) -> &'a [u8] + 'a, ) -> impl Iterator + 'a { - unsafe { - ffi::ts_query_cursor_exec(self.0, query.ptr, node.0); - } + let ptr = self.0.as_ptr(); + unsafe { ffi::ts_query_cursor_exec(ptr, query.ptr.as_ptr(), node.0) }; std::iter::from_fn(move || loop { unsafe { let mut m = MaybeUninit::::uninit(); let mut capture_index = 0u32; if ffi::ts_query_cursor_next_capture( - self.0, + ptr, m.as_mut_ptr(), &mut capture_index as *mut u32, ) { let m = m.assume_init(); let captures = slice::from_raw_parts(m.captures, m.capture_count as usize); - if self.captures_match_condition( + if Self::captures_match_condition( query, captures, m.pattern_index as usize, @@ -1232,7 +1229,6 @@ impl QueryCursor { } fn captures_match_condition<'a>( - &self, query: &'a Query, captures: &'a [ffi::TSQueryCapture], pattern_index: usize, @@ -1268,14 +1264,14 @@ impl QueryCursor { pub fn set_byte_range(&mut self, start: usize, end: usize) -> &mut Self { unsafe { - ffi::ts_query_cursor_set_byte_range(self.0, start as u32, end as u32); + ffi::ts_query_cursor_set_byte_range(self.0.as_ptr(), start as u32, end as u32); } self } pub fn set_point_range(&mut self, start: Point, end: Point) -> &mut Self { unsafe { - ffi::ts_query_cursor_set_point_range(self.0, start.into(), end.into()); + ffi::ts_query_cursor_set_point_range(self.0.as_ptr(), start.into(), end.into()); } self } @@ -1298,13 +1294,13 @@ impl PartialEq for Query { impl Drop for Query { fn drop(&mut self) { - unsafe { ffi::ts_query_delete(self.ptr) } + unsafe { ffi::ts_query_delete(self.ptr.as_ptr()) } } } impl Drop for QueryCursor { fn drop(&mut self) { - unsafe { ffi::ts_query_cursor_delete(self.0) } + unsafe { ffi::ts_query_cursor_delete(self.0.as_ptr()) } } } @@ -1520,3 +1516,10 @@ impl std::error::Error for PropertySheetError { } } } + +unsafe impl Send for Language {} +unsafe impl Send for Parser {} +unsafe impl Send for Query {} +unsafe impl Send for Tree {} +unsafe impl Sync for Language {} +unsafe impl Sync for Query {}