From 500f4326d5565388acccd1a33bfc5ad25ff698c7 Mon Sep 17 00:00:00 2001 From: Amaan Qureshi Date: Wed, 30 Oct 2024 23:49:42 -0400 Subject: [PATCH] feat: add the ability to specify a custom decode function --- Cargo.lock | 17 ++++ cli/Cargo.toml | 2 + cli/src/tests/parser_test.rs | 172 +++++++++++++++++++++++++++++++- docs/section-2-using-parsers.md | 13 +++ lib/binding_rust/bindings.rs | 7 +- lib/binding_rust/lib.rs | 115 +++++++++++++++++++++ lib/include/tree_sitter/api.h | 13 +++ lib/src/lexer.c | 15 +-- lib/src/parser.c | 1 + lib/src/unicode.h | 8 -- 10 files changed, 347 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a66fb201..c2a46231 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -509,6 +509,15 @@ version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" +[[package]] +name = "encoding_rs" +version = "0.8.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" +dependencies = [ + "cfg-if", +] + [[package]] name = "equivalent" version = "1.0.1" @@ -1551,6 +1560,7 @@ dependencies = [ "ctrlc", "dialoguer", "dirs", + "encoding_rs", "filetime", "glob", "heck", @@ -1586,6 +1596,7 @@ dependencies = [ "walkdir", "wasmparser", "webbrowser", + "widestring", ] [[package]] @@ -2071,6 +2082,12 @@ dependencies = [ "web-sys", ] +[[package]] +name = "widestring" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7219d36b6eac893fa81e84ebe06485e7dcbb616177469b142df14f1f4deb1311" + [[package]] name = "winapi-util" version = "0.1.9" diff --git a/cli/Cargo.toml b/cli/Cargo.toml index c96dcb78..c6dfa9e8 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -71,6 +71,8 @@ tree-sitter-loader.workspace = true tree-sitter-tags.workspace = true [dev-dependencies] +encoding_rs = "0.8.35" +widestring = "1.1.0" tree_sitter_proc_macro = { path = "src/tests/proc_macro", package = "tree-sitter-tests-proc-macro" } tempfile.workspace = true diff --git a/cli/src/tests/parser_test.rs b/cli/src/tests/parser_test.rs index 776ed750..1f2bc6e7 100644 --- a/cli/src/tests/parser_test.rs +++ b/cli/src/tests/parser_test.rs @@ -4,7 +4,7 @@ use std::{ }; use tree_sitter::{ - IncludedRangesError, InputEdit, LogType, ParseOptions, ParseState, Parser, Point, Range, + Decode, IncludedRangesError, InputEdit, LogType, ParseOptions, ParseState, Parser, Point, Range, }; use tree_sitter_generate::{generate_parser_for_grammar, load_grammar_file}; use tree_sitter_proc_macro::retry; @@ -1646,6 +1646,176 @@ fn test_parsing_by_halting_at_offset() { assert!(seen_byte_offsets.len() > 100); } +#[test] +fn test_decode_utf32() { + use widestring::u32cstr; + + let mut parser = Parser::new(); + parser.set_language(&get_language("rust")).unwrap(); + + let utf32_text = u32cstr!("pub fn foo() { println!(\"€50\"); }"); + let utf32_text = unsafe { + std::slice::from_raw_parts(utf32_text.as_ptr().cast::(), utf32_text.len() * 4) + }; + + struct U32Decoder; + + impl Decode for U32Decoder { + fn decode(bytes: &[u8]) -> (i32, u32) { + if bytes.len() >= 4 { + #[cfg(target_endian = "big")] + { + ( + i32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]), + 4, + ) + } + + #[cfg(target_endian = "little")] + { + ( + i32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]), + 4, + ) + } + } else { + println!("bad decode: {bytes:?}"); + (0, 0) + } + } + } + + let tree = parser + .parse_custom_encoding::( + &mut |offset, _| { + if offset < utf32_text.len() { + &utf32_text[offset..] + } else { + &[] + } + }, + None, + None, + ) + .unwrap(); + + assert_eq!( + tree.root_node().to_sexp(), + "(source_file (function_item (visibility_modifier) name: (identifier) parameters: (parameters) body: (block (expression_statement (macro_invocation macro: (identifier) (token_tree (string_literal (string_content))))))))" + ); +} + +#[test] +fn test_decode_cp1252() { + use encoding_rs::WINDOWS_1252; + + let mut parser = Parser::new(); + parser.set_language(&get_language("rust")).unwrap(); + + let windows_1252_text = WINDOWS_1252.encode("pub fn foo() { println!(\"€50\"); }").0; + + struct Cp1252Decoder; + + impl Decode for Cp1252Decoder { + fn decode(bytes: &[u8]) -> (i32, u32) { + if !bytes.is_empty() { + let byte = bytes[0]; + (byte as i32, 1) + } else { + (0, 0) + } + } + } + + let tree = parser + .parse_custom_encoding::( + &mut |offset, _| &windows_1252_text[offset..], + None, + None, + ) + .unwrap(); + + assert_eq!( + tree.root_node().to_sexp(), + "(source_file (function_item (visibility_modifier) name: (identifier) parameters: (parameters) body: (block (expression_statement (macro_invocation macro: (identifier) (token_tree (string_literal (string_content))))))))" + ); +} + +#[test] +fn test_decode_macintosh() { + use encoding_rs::MACINTOSH; + + let mut parser = Parser::new(); + parser.set_language(&get_language("rust")).unwrap(); + + let macintosh_text = MACINTOSH.encode("pub fn foo() { println!(\"€50\"); }").0; + + struct MacintoshDecoder; + + impl Decode for MacintoshDecoder { + fn decode(bytes: &[u8]) -> (i32, u32) { + if !bytes.is_empty() { + let byte = bytes[0]; + (byte as i32, 1) + } else { + (0, 0) + } + } + } + + let tree = parser + .parse_custom_encoding::( + &mut |offset, _| &macintosh_text[offset..], + None, + None, + ) + .unwrap(); + + assert_eq!( + tree.root_node().to_sexp(), + "(source_file (function_item (visibility_modifier) name: (identifier) parameters: (parameters) body: (block (expression_statement (macro_invocation macro: (identifier) (token_tree (string_literal (string_content))))))))" + ); +} + +#[test] +fn test_decode_utf24le() { + let mut parser = Parser::new(); + parser.set_language(&get_language("rust")).unwrap(); + + let mut utf24le_text = Vec::new(); + for c in "pub fn foo() { println!(\"€50\"); }".chars() { + let code_point = c as u32; + utf24le_text.push((code_point & 0xFF) as u8); + utf24le_text.push(((code_point >> 8) & 0xFF) as u8); + utf24le_text.push(((code_point >> 16) & 0xFF) as u8); + } + + struct Utf24LeDecoder; + + impl Decode for Utf24LeDecoder { + fn decode(bytes: &[u8]) -> (i32, u32) { + if bytes.len() >= 3 { + (i32::from_le_bytes([bytes[0], bytes[1], bytes[2], 0]), 3) + } else { + (0, 0) + } + } + } + + let tree = parser + .parse_custom_encoding::( + &mut |offset, _| &utf24le_text[offset..], + None, + None, + ) + .unwrap(); + + assert_eq!( + tree.root_node().to_sexp(), + "(source_file (function_item (visibility_modifier) name: (identifier) parameters: (parameters) body: (block (expression_statement (macro_invocation macro: (identifier) (token_tree (string_literal (string_content))))))))" + ); +} + const fn simple_range(start: usize, end: usize) -> Range { Range { start_byte: start, diff --git a/docs/section-2-using-parsers.md b/docs/section-2-using-parsers.md index 266ed2a7..688af1ab 100644 --- a/docs/section-2-using-parsers.md +++ b/docs/section-2-using-parsers.md @@ -149,9 +149,22 @@ typedef struct { uint32_t *bytes_read ); TSInputEncoding encoding; + DecodeFunction decode; } TSInput; ``` +In the event that you want to decode text that is not encoded in UTF-8 or UTF16, then you can set the `decode` field of the input to your function that will decode text. The signature of the `DecodeFunction` is as follows: + +```c +typedef uint32_t (*DecodeFunction)( + const uint8_t *string, + uint32_t length, + int32_t *code_point +); +``` + +The `string` argument is a pointer to the text to decode, which comes from the `read` function, and the `length` argument is the length of the `string`. The `code_point` argument is a pointer to an integer that represents the decoded code point, and should be written to in your `decode` callback. The function should return the number of bytes decoded. + ### Syntax Nodes Tree-sitter provides a [DOM](https://en.wikipedia.org/wiki/Document_Object_Model)-style interface for inspecting syntax trees. A syntax node's _type_ is a string that indicates which grammar rule the node represents. diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index 7a26a075..692194ce 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -35,9 +35,13 @@ pub struct TSQueryCursor { pub struct TSLookaheadIterator { _unused: [u8; 0], } +pub type DecodeFunction = ::core::option::Option< + unsafe extern "C" fn(string: *const u8, length: u32, code_point: *mut i32) -> u32, +>; pub const TSInputEncodingUTF8: TSInputEncoding = 0; pub const TSInputEncodingUTF16LE: TSInputEncoding = 1; pub const TSInputEncodingUTF16BE: TSInputEncoding = 2; +pub const TSInputEncodingCustom: TSInputEncoding = 3; pub type TSInputEncoding = ::core::ffi::c_uint; pub const TSSymbolTypeRegular: TSSymbolType = 0; pub const TSSymbolTypeAnonymous: TSSymbolType = 1; @@ -71,6 +75,7 @@ pub struct TSInput { ) -> *const ::core::ffi::c_char, >, pub encoding: TSInputEncoding, + pub decode: DecodeFunction, } #[repr(C)] #[derive(Debug, Copy, Clone)] @@ -212,7 +217,7 @@ extern "C" { ) -> *mut TSTree; } extern "C" { - #[doc = " Use the parser to parse some source code and create a syntax tree, with some options.\n\n See [`ts_parser_parse`] for more details."] + #[doc = " Use the parser to parse some source code and create a syntax tree, with some options.\n\n See [`ts_parser_parse`] for more details.\n\n See [`TSParseOptions`] for more details on the options."] pub fn ts_parser_parse_with_options( self_: *mut TSParser, old_tree: *const TSTree, diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index b16af5ad..462bddb6 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -210,6 +210,12 @@ type ParseProgressCallback<'a> = &'a mut dyn FnMut(&ParseState) -> bool; /// A callback that receives the query state during query execution. type QueryProgressCallback<'a> = &'a mut dyn FnMut(&QueryCursorState) -> bool; +pub trait Decode { + /// A callback that decodes the next code point from the input slice. It should return the code + /// point, and how many bytes were decoded. + fn decode(bytes: &[u8]) -> (i32, u32); +} + /// A stateful object for walking a syntax [`Tree`] efficiently. #[doc(alias = "TSTreeCursor")] pub struct TreeCursor<'cursor>(ffi::TSTreeCursor, PhantomData<&'cursor ()>); @@ -821,6 +827,7 @@ impl Parser { payload: ptr::addr_of_mut!(payload).cast::(), read: Some(read::), encoding: ffi::TSInputEncodingUTF8, + decode: None, }; let c_old_tree = old_tree.map_or(ptr::null_mut(), |t| t.0.as_ptr()); @@ -956,6 +963,7 @@ impl Parser { payload: core::ptr::addr_of_mut!(payload).cast::(), read: Some(read::), encoding: ffi::TSInputEncodingUTF16LE, + decode: None, }; let c_old_tree = old_tree.map_or(ptr::null_mut(), |t| t.0.as_ptr()); @@ -1070,6 +1078,113 @@ impl Parser { payload: core::ptr::addr_of_mut!(payload).cast::(), read: Some(read::), encoding: ffi::TSInputEncodingUTF16BE, + decode: None, + }; + + let c_old_tree = old_tree.map_or(ptr::null_mut(), |t| t.0.as_ptr()); + unsafe { + let c_new_tree = ffi::ts_parser_parse_with_options( + self.0.as_ptr(), + c_old_tree, + c_input, + parse_options, + ); + + NonNull::new(c_new_tree).map(Tree) + } + } + + /// Parse text provided in chunks by a callback using a custom encoding. + /// This is useful for parsing text in encodings that are not UTF-8 or UTF-16. + /// + /// # Arguments: + /// * `callback` A function that takes a byte offset and position and returns a slice of text + /// starting at that byte offset and position. The slices can be of any length. If the given + /// position is at the end of the text, the callback should return an empty slice. + /// * `old_tree` A previous syntax tree parsed from the same document. If the text of the + /// document has changed since `old_tree` was created, then you must edit `old_tree` to match + /// the new text using [`Tree::edit`]. + /// * `options` Options for parsing the text. This can be used to set a progress callback. + /// + /// Additionally, you must set the generic parameter [`D`] to a type that implements the + /// [`Decode`] trait. This trait has a single method, [`decode`](Decode::decode), which takes a + /// slice of bytes and returns a tuple of the code point and the number of bytes consumed. + /// The `decode` method should return `-1` for the code point if decoding fails. + pub fn parse_custom_encoding, F: FnMut(usize, Point) -> T>( + &mut self, + callback: &mut F, + old_tree: Option<&Tree>, + options: Option, + ) -> Option { + type Payload<'a, F, T> = (&'a mut F, Option); + + unsafe extern "C" fn progress(state: *mut ffi::TSParseState) -> bool { + let callback = (*state) + .payload + .cast::() + .as_mut() + .unwrap(); + callback(&ParseState::from_raw(state)) + } + + // At compile time, create a C-compatible callback that calls the custom `decode` method. + unsafe extern "C" fn decode_fn( + data: *const u8, + len: u32, + code_point: *mut i32, + ) -> u32 { + let (c, len) = D::decode(std::slice::from_raw_parts(data, len as usize)); + if let Some(code_point) = code_point.as_mut() { + *code_point = c; + } + len + } + + // This C function is passed to Tree-sitter as the input callback. + unsafe extern "C" fn read, F: FnMut(usize, Point) -> T>( + payload: *mut c_void, + byte_offset: u32, + position: ffi::TSPoint, + bytes_read: *mut u32, + ) -> *const c_char { + let (callback, text) = payload.cast::>().as_mut().unwrap(); + *text = Some(callback(byte_offset as usize, position.into())); + let slice = text.as_ref().unwrap().as_ref(); + *bytes_read = slice.len() as u32; + slice.as_ptr().cast::() + } + + let empty_options = ffi::TSParseOptions { + payload: ptr::null_mut(), + progress_callback: None, + }; + + let parse_options = if let Some(options) = options { + if let Some(mut cb) = options.progress_callback { + ffi::TSParseOptions { + payload: core::ptr::addr_of_mut!(cb).cast::(), + progress_callback: Some(progress), + } + } else { + empty_options + } + } else { + empty_options + }; + + // A pointer to this payload is passed on every call to the `read` C function. + // The payload contains two things: + // 1. A reference to the rust `callback`. + // 2. The text that was returned from the previous call to `callback`. This allows the + // callback to return owned values like vectors. + let mut payload: Payload = (callback, None); + + let c_input = ffi::TSInput { + payload: core::ptr::addr_of_mut!(payload).cast::(), + read: Some(read::), + encoding: ffi::TSInputEncodingCustom, + // Use this custom decode callback + decode: Some(decode_fn::), }; let c_old_tree = old_tree.map_or(ptr::null_mut(), |t| t.0.as_ptr()); diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 31a39d46..e4650a4b 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -48,10 +48,20 @@ typedef struct TSQuery TSQuery; typedef struct TSQueryCursor TSQueryCursor; typedef struct TSLookaheadIterator TSLookaheadIterator; +// This function signature reads one code point from the given string, +// returning the number of bytes consumed. It should write the code point +// to the `code_point` pointer, or write -1 if the input is invalid. +typedef uint32_t (*DecodeFunction)( + const uint8_t *string, + uint32_t length, + int32_t *code_point +); + typedef enum TSInputEncoding { TSInputEncodingUTF8, TSInputEncodingUTF16LE, TSInputEncodingUTF16BE, + TSInputEncodingCustom } TSInputEncoding; typedef enum TSSymbolType { @@ -77,6 +87,7 @@ typedef struct TSInput { void *payload; const char *(*read)(void *payload, uint32_t byte_index, TSPoint position, uint32_t *bytes_read); TSInputEncoding encoding; + DecodeFunction decode; } TSInput; typedef struct TSParseState { @@ -297,6 +308,8 @@ TSTree *ts_parser_parse( * Use the parser to parse some source code and create a syntax tree, with some options. * * See [`ts_parser_parse`] for more details. + * + * See [`TSParseOptions`] for more details on the options. */ TSTree* ts_parser_parse_with_options( TSParser *self, diff --git a/lib/src/lexer.c b/lib/src/lexer.c index 21448a2e..f181946a 100644 --- a/lib/src/lexer.c +++ b/lib/src/lexer.c @@ -1,9 +1,11 @@ -#include -#include "./lexer.h" -#include "./subtree.h" #include "./length.h" +#include "./lexer.h" #include "./unicode.h" + +#include "tree_sitter/api.h" + #include +#include #define LOG(message, character) \ if (self->logger.log) { \ @@ -112,9 +114,10 @@ static void ts_lexer__get_lookahead(Lexer *self) { } const uint8_t *chunk = (const uint8_t *)self->chunk + position_in_chunk; - UnicodeDecodeFunction decode = - self->input.encoding == TSInputEncodingUTF8 ? ts_decode_utf8 : - self->input.encoding == TSInputEncodingUTF16LE ? ts_decode_utf16_le : ts_decode_utf16_be; + DecodeFunction decode = + self->input.encoding == TSInputEncodingUTF8 ? ts_decode_utf8 : + self->input.encoding == TSInputEncodingUTF16LE ? ts_decode_utf16_le : + self->input.encoding == TSInputEncodingUTF16BE ? ts_decode_utf16_be : self->input.decode; self->lookahead_size = decode(chunk, size, &self->data.lookahead); diff --git a/lib/src/parser.c b/lib/src/parser.c index 9c5ddeee..4c2976cd 100644 --- a/lib/src/parser.c +++ b/lib/src/parser.c @@ -2163,6 +2163,7 @@ TSTree *ts_parser_parse_string_encoding( &input, ts_string_input_read, encoding, + NULL, }); } diff --git a/lib/src/unicode.h b/lib/src/unicode.h index efeee1fd..0fba3f29 100644 --- a/lib/src/unicode.h +++ b/lib/src/unicode.h @@ -38,14 +38,6 @@ extern "C" { static const int32_t TS_DECODE_ERROR = U_SENTINEL; -// These functions read one unicode code point from the given string, -// returning the number of bytes consumed. -typedef uint32_t (*UnicodeDecodeFunction)( - const uint8_t *string, - uint32_t length, - int32_t *code_point -); - static inline uint32_t ts_decode_utf8( const uint8_t *string, uint32_t length,