feat: add the ability to specify a custom decode function

This commit is contained in:
Amaan Qureshi 2024-10-30 23:49:42 -04:00
parent e27160b118
commit 500f4326d5
10 changed files with 347 additions and 16 deletions

17
Cargo.lock generated
View file

@ -509,6 +509,15 @@ version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f"
[[package]]
name = "encoding_rs"
version = "0.8.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3"
dependencies = [
"cfg-if",
]
[[package]]
name = "equivalent"
version = "1.0.1"
@ -1551,6 +1560,7 @@ dependencies = [
"ctrlc",
"dialoguer",
"dirs",
"encoding_rs",
"filetime",
"glob",
"heck",
@ -1586,6 +1596,7 @@ dependencies = [
"walkdir",
"wasmparser",
"webbrowser",
"widestring",
]
[[package]]
@ -2071,6 +2082,12 @@ dependencies = [
"web-sys",
]
[[package]]
name = "widestring"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7219d36b6eac893fa81e84ebe06485e7dcbb616177469b142df14f1f4deb1311"
[[package]]
name = "winapi-util"
version = "0.1.9"

View file

@ -71,6 +71,8 @@ tree-sitter-loader.workspace = true
tree-sitter-tags.workspace = true
[dev-dependencies]
encoding_rs = "0.8.35"
widestring = "1.1.0"
tree_sitter_proc_macro = { path = "src/tests/proc_macro", package = "tree-sitter-tests-proc-macro" }
tempfile.workspace = true

View file

@ -4,7 +4,7 @@ use std::{
};
use tree_sitter::{
IncludedRangesError, InputEdit, LogType, ParseOptions, ParseState, Parser, Point, Range,
Decode, IncludedRangesError, InputEdit, LogType, ParseOptions, ParseState, Parser, Point, Range,
};
use tree_sitter_generate::{generate_parser_for_grammar, load_grammar_file};
use tree_sitter_proc_macro::retry;
@ -1646,6 +1646,176 @@ fn test_parsing_by_halting_at_offset() {
assert!(seen_byte_offsets.len() > 100);
}
#[test]
fn test_decode_utf32() {
use widestring::u32cstr;
let mut parser = Parser::new();
parser.set_language(&get_language("rust")).unwrap();
let utf32_text = u32cstr!("pub fn foo() { println!(\"€50\"); }");
let utf32_text = unsafe {
std::slice::from_raw_parts(utf32_text.as_ptr().cast::<u8>(), utf32_text.len() * 4)
};
struct U32Decoder;
impl Decode for U32Decoder {
fn decode(bytes: &[u8]) -> (i32, u32) {
if bytes.len() >= 4 {
#[cfg(target_endian = "big")]
{
(
i32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]),
4,
)
}
#[cfg(target_endian = "little")]
{
(
i32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]),
4,
)
}
} else {
println!("bad decode: {bytes:?}");
(0, 0)
}
}
}
let tree = parser
.parse_custom_encoding::<U32Decoder, _, _>(
&mut |offset, _| {
if offset < utf32_text.len() {
&utf32_text[offset..]
} else {
&[]
}
},
None,
None,
)
.unwrap();
assert_eq!(
tree.root_node().to_sexp(),
"(source_file (function_item (visibility_modifier) name: (identifier) parameters: (parameters) body: (block (expression_statement (macro_invocation macro: (identifier) (token_tree (string_literal (string_content))))))))"
);
}
#[test]
fn test_decode_cp1252() {
use encoding_rs::WINDOWS_1252;
let mut parser = Parser::new();
parser.set_language(&get_language("rust")).unwrap();
let windows_1252_text = WINDOWS_1252.encode("pub fn foo() { println!(\"€50\"); }").0;
struct Cp1252Decoder;
impl Decode for Cp1252Decoder {
fn decode(bytes: &[u8]) -> (i32, u32) {
if !bytes.is_empty() {
let byte = bytes[0];
(byte as i32, 1)
} else {
(0, 0)
}
}
}
let tree = parser
.parse_custom_encoding::<Cp1252Decoder, _, _>(
&mut |offset, _| &windows_1252_text[offset..],
None,
None,
)
.unwrap();
assert_eq!(
tree.root_node().to_sexp(),
"(source_file (function_item (visibility_modifier) name: (identifier) parameters: (parameters) body: (block (expression_statement (macro_invocation macro: (identifier) (token_tree (string_literal (string_content))))))))"
);
}
#[test]
fn test_decode_macintosh() {
use encoding_rs::MACINTOSH;
let mut parser = Parser::new();
parser.set_language(&get_language("rust")).unwrap();
let macintosh_text = MACINTOSH.encode("pub fn foo() { println!(\"€50\"); }").0;
struct MacintoshDecoder;
impl Decode for MacintoshDecoder {
fn decode(bytes: &[u8]) -> (i32, u32) {
if !bytes.is_empty() {
let byte = bytes[0];
(byte as i32, 1)
} else {
(0, 0)
}
}
}
let tree = parser
.parse_custom_encoding::<MacintoshDecoder, _, _>(
&mut |offset, _| &macintosh_text[offset..],
None,
None,
)
.unwrap();
assert_eq!(
tree.root_node().to_sexp(),
"(source_file (function_item (visibility_modifier) name: (identifier) parameters: (parameters) body: (block (expression_statement (macro_invocation macro: (identifier) (token_tree (string_literal (string_content))))))))"
);
}
#[test]
fn test_decode_utf24le() {
let mut parser = Parser::new();
parser.set_language(&get_language("rust")).unwrap();
let mut utf24le_text = Vec::new();
for c in "pub fn foo() { println!(\"€50\"); }".chars() {
let code_point = c as u32;
utf24le_text.push((code_point & 0xFF) as u8);
utf24le_text.push(((code_point >> 8) & 0xFF) as u8);
utf24le_text.push(((code_point >> 16) & 0xFF) as u8);
}
struct Utf24LeDecoder;
impl Decode for Utf24LeDecoder {
fn decode(bytes: &[u8]) -> (i32, u32) {
if bytes.len() >= 3 {
(i32::from_le_bytes([bytes[0], bytes[1], bytes[2], 0]), 3)
} else {
(0, 0)
}
}
}
let tree = parser
.parse_custom_encoding::<Utf24LeDecoder, _, _>(
&mut |offset, _| &utf24le_text[offset..],
None,
None,
)
.unwrap();
assert_eq!(
tree.root_node().to_sexp(),
"(source_file (function_item (visibility_modifier) name: (identifier) parameters: (parameters) body: (block (expression_statement (macro_invocation macro: (identifier) (token_tree (string_literal (string_content))))))))"
);
}
const fn simple_range(start: usize, end: usize) -> Range {
Range {
start_byte: start,

View file

@ -149,9 +149,22 @@ typedef struct {
uint32_t *bytes_read
);
TSInputEncoding encoding;
DecodeFunction decode;
} TSInput;
```
In the event that you want to decode text that is not encoded in UTF-8 or UTF16, then you can set the `decode` field of the input to your function that will decode text. The signature of the `DecodeFunction` is as follows:
```c
typedef uint32_t (*DecodeFunction)(
const uint8_t *string,
uint32_t length,
int32_t *code_point
);
```
The `string` argument is a pointer to the text to decode, which comes from the `read` function, and the `length` argument is the length of the `string`. The `code_point` argument is a pointer to an integer that represents the decoded code point, and should be written to in your `decode` callback. The function should return the number of bytes decoded.
### Syntax Nodes
Tree-sitter provides a [DOM](https://en.wikipedia.org/wiki/Document_Object_Model)-style interface for inspecting syntax trees. A syntax node's _type_ is a string that indicates which grammar rule the node represents.

View file

@ -35,9 +35,13 @@ pub struct TSQueryCursor {
pub struct TSLookaheadIterator {
_unused: [u8; 0],
}
pub type DecodeFunction = ::core::option::Option<
unsafe extern "C" fn(string: *const u8, length: u32, code_point: *mut i32) -> u32,
>;
pub const TSInputEncodingUTF8: TSInputEncoding = 0;
pub const TSInputEncodingUTF16LE: TSInputEncoding = 1;
pub const TSInputEncodingUTF16BE: TSInputEncoding = 2;
pub const TSInputEncodingCustom: TSInputEncoding = 3;
pub type TSInputEncoding = ::core::ffi::c_uint;
pub const TSSymbolTypeRegular: TSSymbolType = 0;
pub const TSSymbolTypeAnonymous: TSSymbolType = 1;
@ -71,6 +75,7 @@ pub struct TSInput {
) -> *const ::core::ffi::c_char,
>,
pub encoding: TSInputEncoding,
pub decode: DecodeFunction,
}
#[repr(C)]
#[derive(Debug, Copy, Clone)]
@ -212,7 +217,7 @@ extern "C" {
) -> *mut TSTree;
}
extern "C" {
#[doc = " Use the parser to parse some source code and create a syntax tree, with some options.\n\n See [`ts_parser_parse`] for more details."]
#[doc = " Use the parser to parse some source code and create a syntax tree, with some options.\n\n See [`ts_parser_parse`] for more details.\n\n See [`TSParseOptions`] for more details on the options."]
pub fn ts_parser_parse_with_options(
self_: *mut TSParser,
old_tree: *const TSTree,

View file

@ -210,6 +210,12 @@ type ParseProgressCallback<'a> = &'a mut dyn FnMut(&ParseState) -> bool;
/// A callback that receives the query state during query execution.
type QueryProgressCallback<'a> = &'a mut dyn FnMut(&QueryCursorState) -> bool;
pub trait Decode {
/// A callback that decodes the next code point from the input slice. It should return the code
/// point, and how many bytes were decoded.
fn decode(bytes: &[u8]) -> (i32, u32);
}
/// A stateful object for walking a syntax [`Tree`] efficiently.
#[doc(alias = "TSTreeCursor")]
pub struct TreeCursor<'cursor>(ffi::TSTreeCursor, PhantomData<&'cursor ()>);
@ -821,6 +827,7 @@ impl Parser {
payload: ptr::addr_of_mut!(payload).cast::<c_void>(),
read: Some(read::<T, F>),
encoding: ffi::TSInputEncodingUTF8,
decode: None,
};
let c_old_tree = old_tree.map_or(ptr::null_mut(), |t| t.0.as_ptr());
@ -956,6 +963,7 @@ impl Parser {
payload: core::ptr::addr_of_mut!(payload).cast::<c_void>(),
read: Some(read::<T, F>),
encoding: ffi::TSInputEncodingUTF16LE,
decode: None,
};
let c_old_tree = old_tree.map_or(ptr::null_mut(), |t| t.0.as_ptr());
@ -1070,6 +1078,113 @@ impl Parser {
payload: core::ptr::addr_of_mut!(payload).cast::<c_void>(),
read: Some(read::<T, F>),
encoding: ffi::TSInputEncodingUTF16BE,
decode: None,
};
let c_old_tree = old_tree.map_or(ptr::null_mut(), |t| t.0.as_ptr());
unsafe {
let c_new_tree = ffi::ts_parser_parse_with_options(
self.0.as_ptr(),
c_old_tree,
c_input,
parse_options,
);
NonNull::new(c_new_tree).map(Tree)
}
}
/// Parse text provided in chunks by a callback using a custom encoding.
/// This is useful for parsing text in encodings that are not UTF-8 or UTF-16.
///
/// # Arguments:
/// * `callback` A function that takes a byte offset and position and returns a slice of text
/// starting at that byte offset and position. The slices can be of any length. If the given
/// position is at the end of the text, the callback should return an empty slice.
/// * `old_tree` A previous syntax tree parsed from the same document. If the text of the
/// document has changed since `old_tree` was created, then you must edit `old_tree` to match
/// the new text using [`Tree::edit`].
/// * `options` Options for parsing the text. This can be used to set a progress callback.
///
/// Additionally, you must set the generic parameter [`D`] to a type that implements the
/// [`Decode`] trait. This trait has a single method, [`decode`](Decode::decode), which takes a
/// slice of bytes and returns a tuple of the code point and the number of bytes consumed.
/// The `decode` method should return `-1` for the code point if decoding fails.
pub fn parse_custom_encoding<D: Decode, T: AsRef<[u8]>, F: FnMut(usize, Point) -> T>(
&mut self,
callback: &mut F,
old_tree: Option<&Tree>,
options: Option<ParseOptions>,
) -> Option<Tree> {
type Payload<'a, F, T> = (&'a mut F, Option<T>);
unsafe extern "C" fn progress(state: *mut ffi::TSParseState) -> bool {
let callback = (*state)
.payload
.cast::<ParseProgressCallback>()
.as_mut()
.unwrap();
callback(&ParseState::from_raw(state))
}
// At compile time, create a C-compatible callback that calls the custom `decode` method.
unsafe extern "C" fn decode_fn<D: Decode>(
data: *const u8,
len: u32,
code_point: *mut i32,
) -> u32 {
let (c, len) = D::decode(std::slice::from_raw_parts(data, len as usize));
if let Some(code_point) = code_point.as_mut() {
*code_point = c;
}
len
}
// This C function is passed to Tree-sitter as the input callback.
unsafe extern "C" fn read<T: AsRef<[u8]>, F: FnMut(usize, Point) -> T>(
payload: *mut c_void,
byte_offset: u32,
position: ffi::TSPoint,
bytes_read: *mut u32,
) -> *const c_char {
let (callback, text) = payload.cast::<Payload<F, T>>().as_mut().unwrap();
*text = Some(callback(byte_offset as usize, position.into()));
let slice = text.as_ref().unwrap().as_ref();
*bytes_read = slice.len() as u32;
slice.as_ptr().cast::<c_char>()
}
let empty_options = ffi::TSParseOptions {
payload: ptr::null_mut(),
progress_callback: None,
};
let parse_options = if let Some(options) = options {
if let Some(mut cb) = options.progress_callback {
ffi::TSParseOptions {
payload: core::ptr::addr_of_mut!(cb).cast::<c_void>(),
progress_callback: Some(progress),
}
} else {
empty_options
}
} else {
empty_options
};
// A pointer to this payload is passed on every call to the `read` C function.
// The payload contains two things:
// 1. A reference to the rust `callback`.
// 2. The text that was returned from the previous call to `callback`. This allows the
// callback to return owned values like vectors.
let mut payload: Payload<F, T> = (callback, None);
let c_input = ffi::TSInput {
payload: core::ptr::addr_of_mut!(payload).cast::<c_void>(),
read: Some(read::<T, F>),
encoding: ffi::TSInputEncodingCustom,
// Use this custom decode callback
decode: Some(decode_fn::<D>),
};
let c_old_tree = old_tree.map_or(ptr::null_mut(), |t| t.0.as_ptr());

View file

@ -48,10 +48,20 @@ typedef struct TSQuery TSQuery;
typedef struct TSQueryCursor TSQueryCursor;
typedef struct TSLookaheadIterator TSLookaheadIterator;
// This function signature reads one code point from the given string,
// returning the number of bytes consumed. It should write the code point
// to the `code_point` pointer, or write -1 if the input is invalid.
typedef uint32_t (*DecodeFunction)(
const uint8_t *string,
uint32_t length,
int32_t *code_point
);
typedef enum TSInputEncoding {
TSInputEncodingUTF8,
TSInputEncodingUTF16LE,
TSInputEncodingUTF16BE,
TSInputEncodingCustom
} TSInputEncoding;
typedef enum TSSymbolType {
@ -77,6 +87,7 @@ typedef struct TSInput {
void *payload;
const char *(*read)(void *payload, uint32_t byte_index, TSPoint position, uint32_t *bytes_read);
TSInputEncoding encoding;
DecodeFunction decode;
} TSInput;
typedef struct TSParseState {
@ -297,6 +308,8 @@ TSTree *ts_parser_parse(
* Use the parser to parse some source code and create a syntax tree, with some options.
*
* See [`ts_parser_parse`] for more details.
*
* See [`TSParseOptions`] for more details on the options.
*/
TSTree* ts_parser_parse_with_options(
TSParser *self,

View file

@ -1,9 +1,11 @@
#include <stdio.h>
#include "./lexer.h"
#include "./subtree.h"
#include "./length.h"
#include "./lexer.h"
#include "./unicode.h"
#include "tree_sitter/api.h"
#include <stdarg.h>
#include <stdio.h>
#define LOG(message, character) \
if (self->logger.log) { \
@ -112,9 +114,10 @@ static void ts_lexer__get_lookahead(Lexer *self) {
}
const uint8_t *chunk = (const uint8_t *)self->chunk + position_in_chunk;
UnicodeDecodeFunction decode =
self->input.encoding == TSInputEncodingUTF8 ? ts_decode_utf8 :
self->input.encoding == TSInputEncodingUTF16LE ? ts_decode_utf16_le : ts_decode_utf16_be;
DecodeFunction decode =
self->input.encoding == TSInputEncodingUTF8 ? ts_decode_utf8 :
self->input.encoding == TSInputEncodingUTF16LE ? ts_decode_utf16_le :
self->input.encoding == TSInputEncodingUTF16BE ? ts_decode_utf16_be : self->input.decode;
self->lookahead_size = decode(chunk, size, &self->data.lookahead);

View file

@ -2163,6 +2163,7 @@ TSTree *ts_parser_parse_string_encoding(
&input,
ts_string_input_read,
encoding,
NULL,
});
}

View file

@ -38,14 +38,6 @@ extern "C" {
static const int32_t TS_DECODE_ERROR = U_SENTINEL;
// These functions read one unicode code point from the given string,
// returning the number of bytes consumed.
typedef uint32_t (*UnicodeDecodeFunction)(
const uint8_t *string,
uint32_t length,
int32_t *code_point
);
static inline uint32_t ts_decode_utf8(
const uint8_t *string,
uint32_t length,