From fe7c74e7aa90e4f935c466763c84aad4449742cb Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 9 Sep 2019 15:41:13 -0700 Subject: [PATCH] Start work on an API for querying trees --- cli/src/tests/mod.rs | 1 + cli/src/tests/query_test.rs | 216 +++++++++ lib/binding_rust/bindings.rs | 80 ++++ lib/binding_rust/lib.rs | 130 +++++- lib/include/tree_sitter/api.h | 115 +++++ lib/src/bits.h | 25 ++ lib/src/lib.c | 1 + lib/src/query.c | 810 ++++++++++++++++++++++++++++++++++ lib/src/tree_cursor.c | 63 ++- lib/src/tree_cursor.h | 1 + 10 files changed, 1430 insertions(+), 12 deletions(-) create mode 100644 cli/src/tests/query_test.rs create mode 100644 lib/src/bits.h create mode 100644 lib/src/query.c diff --git a/cli/src/tests/mod.rs b/cli/src/tests/mod.rs index 143e8297..1a2a71ff 100644 --- a/cli/src/tests/mod.rs +++ b/cli/src/tests/mod.rs @@ -4,4 +4,5 @@ mod highlight_test; mod node_test; mod parser_test; mod properties_test; +mod query_test; mod tree_test; diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs new file mode 100644 index 00000000..32adba62 --- /dev/null +++ b/cli/src/tests/query_test.rs @@ -0,0 +1,216 @@ +use super::helpers::allocations; +use super::helpers::fixtures::get_language; +use tree_sitter::{Parser, Query, QueryError, QueryMatch}; + +#[test] +fn test_query_errors_on_invalid_syntax() { + allocations::start_recording(); + + let language = get_language("javascript"); + + assert!(Query::new(language, "(if_statement)").is_ok()); + assert!(Query::new(language, "(if_statement condition:(identifier))").is_ok()); + + // Mismatched parens + assert_eq!( + Query::new(language, "(if_statement"), + Err(QueryError::Syntax(13)) + ); + assert_eq!( + Query::new(language, "(if_statement))"), + Err(QueryError::Syntax(14)) + ); + + // Return an error at the *beginning* of a bare identifier not followed a colon. + // If there's a colon but no pattern, return an error at the end of the colon. + assert_eq!( + Query::new(language, "(if_statement identifier)"), + Err(QueryError::Syntax(14)) + ); + assert_eq!( + Query::new(language, "(if_statement condition:)"), + Err(QueryError::Syntax(24)) + ); + + assert_eq!( + Query::new(language, "(if_statement condition:)"), + Err(QueryError::Syntax(24)) + ); + + allocations::stop_recording(); +} + +#[test] +fn test_query_errors_on_invalid_symbols() { + allocations::start_recording(); + + let language = get_language("javascript"); + + assert_eq!( + Query::new(language, "(non_existent1)"), + Err(QueryError::NodeType("non_existent1")) + ); + assert_eq!( + Query::new(language, "(if_statement (non_existent2))"), + Err(QueryError::NodeType("non_existent2")) + ); + assert_eq!( + Query::new(language, "(if_statement condition: (non_existent3))"), + Err(QueryError::NodeType("non_existent3")) + ); + assert_eq!( + Query::new(language, "(if_statement not_a_field: (identifier))"), + Err(QueryError::Field("not_a_field")) + ); + + allocations::stop_recording(); +} + +#[test] +fn test_query_capture_names() { + allocations::start_recording(); + + let language = get_language("javascript"); + let query = Query::new( + language, + r#" + (if_statement + condition: (binary_expression + left: * @left-operand + operator: "||" + right: * @right-operand) + consequence: (statement_block) @body) + + (while_statement + condition:* @loop-condition) + "#, + ) + .unwrap(); + + assert_eq!( + query.capture_names(), + &[ + "left-operand".to_string(), + "right-operand".to_string(), + "body".to_string(), + "loop-condition".to_string(), + ] + ); + + drop(query); + allocations::stop_recording(); +} + +#[test] +fn test_query_exec_with_simple_pattern() { + allocations::start_recording(); + + let language = get_language("javascript"); + let query = Query::new( + language, + "(function_declaration name: (identifier) @fn-name)", + ) + .unwrap(); + + let source = "function one() { two(); function three() {} }"; + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(source, None).unwrap(); + + let context = query.context(); + let matches = context.exec(tree.root_node()); + + assert_eq!( + collect_matches(matches, &query, source), + &[ + (0, vec![("fn-name", "one")]), + (0, vec![("fn-name", "three")]) + ], + ); + + drop(context); + drop(parser); + drop(query); + drop(tree); + allocations::stop_recording(); +} + +#[test] +fn test_query_exec_with_multiple_matches_same_root() { + allocations::start_recording(); + + let language = get_language("javascript"); + let query = Query::new( + language, + "(class_declaration + name: (identifier) @the-class-name + (class_body + (method_definition + name: (property_identifier) @the-method-name)))", + ) + .unwrap(); + + let source = " + class Person { + // the constructor + constructor(name) { this.name = name; } + + // the getter + getFullName() { return this.name; } + } + "; + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(source, None).unwrap(); + let context = query.context(); + let matches = context.exec(tree.root_node()); + + assert_eq!( + collect_matches(matches, &query, source), + &[ + ( + 0, + vec![ + ("the-class-name", "Person"), + ("the-method-name", "constructor") + ] + ), + ( + 0, + vec![ + ("the-class-name", "Person"), + ("the-method-name", "getFullName") + ] + ), + ], + ); + + drop(context); + drop(parser); + drop(query); + drop(tree); + allocations::stop_recording(); +} + +fn collect_matches<'a>( + matches: impl Iterator>, + query: &'a Query, + source: &'a str, +) -> Vec<(usize, Vec<(&'a str, &'a str)>)> { + matches + .map(|m| { + ( + m.pattern_index(), + m.captures() + .map(|(capture_id, node)| { + ( + query.capture_names()[capture_id].as_str(), + node.utf8_text(source.as_bytes()).unwrap(), + ) + }) + .collect(), + ) + }) + .collect() +} diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index a71b297e..53b77405 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -19,6 +19,16 @@ pub struct TSParser { pub struct TSTree { _unused: [u8; 0], } +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct TSQuery { + _unused: [u8; 0], +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct TSQueryContext { + _unused: [u8; 0], +} pub const TSInputEncoding_TSInputEncodingUTF8: TSInputEncoding = 0; pub const TSInputEncoding_TSInputEncodingUTF16: TSInputEncoding = 1; pub type TSInputEncoding = u32; @@ -93,6 +103,17 @@ pub struct TSTreeCursor { pub id: *const ::std::os::raw::c_void, pub context: [u32; 2usize], } +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct TSQueryCapture { + pub node: TSNode, + pub index: u32, +} +pub const TSQueryError_TSQueryErrorNone: TSQueryError = 0; +pub const TSQueryError_TSQueryErrorSyntax: TSQueryError = 1; +pub const TSQueryError_TSQueryErrorNodeType: TSQueryError = 2; +pub const TSQueryError_TSQueryErrorField: TSQueryError = 3; +pub type TSQueryError = u32; extern "C" { #[doc = " Create a new parser."] pub fn ts_parser_new() -> *mut TSParser; @@ -538,6 +559,65 @@ extern "C" { extern "C" { pub fn ts_tree_cursor_copy(arg1: *const TSTreeCursor) -> TSTreeCursor; } +extern "C" { + #[doc = " Create a new query based on a given language and string containing"] + #[doc = " one or more S-expression patterns."] + #[doc = ""] + #[doc = " If all of the given patterns are valid, this returns a `TSQuery`."] + #[doc = " If a pattern is invalid, this returns `NULL`, and provides two pieces"] + #[doc = " of information about the problem:"] + #[doc = " 1. The byte offset of the error is written to the `error_offset` parameter."] + #[doc = " 2. The type of error is written to the `error_type` parameter."] + pub fn ts_query_new( + arg1: *const TSLanguage, + source: *const ::std::os::raw::c_char, + source_len: u32, + error_offset: *mut u32, + error_type: *mut TSQueryError, + ) -> *mut TSQuery; +} +extern "C" { + #[doc = " Delete a query, freeing all of the memory that it used."] + pub fn ts_query_delete(arg1: *mut TSQuery); +} +extern "C" { + pub fn ts_query_capture_count(arg1: *const TSQuery) -> u32; +} +extern "C" { + pub fn ts_query_capture_name_for_id( + self_: *const TSQuery, + index: u32, + length: *mut u32, + ) -> *const ::std::os::raw::c_char; +} +extern "C" { + pub fn ts_query_capture_id_for_name( + self_: *const TSQuery, + name: *const ::std::os::raw::c_char, + length: u32, + ) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn ts_query_context_new(arg1: *const TSQuery) -> *mut TSQueryContext; +} +extern "C" { + pub fn ts_query_context_delete(arg1: *mut TSQueryContext); +} +extern "C" { + pub fn ts_query_context_exec(arg1: *mut TSQueryContext, arg2: TSNode); +} +extern "C" { + pub fn ts_query_context_next(arg1: *mut TSQueryContext) -> bool; +} +extern "C" { + pub fn ts_query_context_matched_pattern_index(arg1: *const TSQueryContext) -> u32; +} +extern "C" { + pub fn ts_query_context_matched_captures( + arg1: *const TSQueryContext, + arg2: *mut u32, + ) -> *const TSQueryCapture; +} extern "C" { #[doc = " Get the number of distinct node types in the language."] pub fn ts_language_symbol_count(arg1: *const TSLanguage) -> u32; diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 4c34d202..80e56ba9 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -17,7 +17,7 @@ use std::ffi::CStr; use std::marker::PhantomData; use std::os::raw::{c_char, c_void}; use std::sync::atomic::AtomicUsize; -use std::{fmt, ptr, str, u16}; +use std::{char, fmt, ptr, slice, str, u16}; pub const LANGUAGE_VERSION: usize = ffi::TREE_SITTER_LANGUAGE_VERSION; pub const PARSER_HEADER: &'static str = include_str!("../include/tree_sitter/parser.h"); @@ -136,6 +136,23 @@ pub struct TreePropertyCursor<'a, P> { source: &'a [u8], } +#[derive(Debug)] +pub struct Query { + ptr: *mut ffi::TSQuery, + capture_names: Vec, +} + +pub struct QueryContext<'a>(*mut ffi::TSQueryContext, PhantomData<&'a ()>); + +pub struct QueryMatch<'a>(&'a QueryContext<'a>); + +#[derive(Debug, PartialEq, Eq)] +pub enum QueryError<'a> { + Syntax(usize), + NodeType(&'a str), + Field(&'a str), +} + impl Language { pub fn version(&self) -> usize { unsafe { ffi::ts_language_version(self.0) as usize } @@ -921,6 +938,117 @@ impl<'a, P> TreePropertyCursor<'a, P> { } } +impl Query { + pub fn new(language: Language, source: &str) -> Result { + let mut error_offset = 0u32; + let mut error_type: ffi::TSQueryError = 0; + let bytes = source.as_bytes(); + let ptr = unsafe { + ffi::ts_query_new( + language.0, + bytes.as_ptr() as *const c_char, + bytes.len() as u32, + &mut error_offset as *mut u32, + &mut error_type as *mut ffi::TSQueryError, + ) + }; + if ptr.is_null() { + let offset = error_offset as usize; + Err(match error_type { + ffi::TSQueryError_TSQueryErrorNodeType | ffi::TSQueryError_TSQueryErrorField => { + let suffix = source.split_at(offset).1; + let end_offset = suffix + .find(|c| !char::is_alphanumeric(c) && c != '_' && c != '-') + .unwrap_or(source.len()); + let name = suffix.split_at(end_offset).0; + if error_type == ffi::TSQueryError_TSQueryErrorNodeType { + QueryError::NodeType(name) + } else { + QueryError::Field(name) + } + } + _ => QueryError::Syntax(offset), + }) + } else { + let capture_count = unsafe { ffi::ts_query_capture_count(ptr) }; + let capture_names = (0..capture_count) + .map(|i| unsafe { + let mut length = 0u32; + let name = + ffi::ts_query_capture_name_for_id(ptr, i as u32, &mut length as *mut u32) + as *const u8; + let name = slice::from_raw_parts(name, length as usize); + let name = str::from_utf8_unchecked(name); + name.to_string() + }) + .collect(); + Ok(Query { ptr, capture_names }) + } + } + + pub fn capture_names(&self) -> &[String] { + &self.capture_names + } + + pub fn context(&self) -> QueryContext { + let context = unsafe { ffi::ts_query_context_new(self.ptr) }; + QueryContext(context, PhantomData) + } +} + +impl<'a> QueryContext<'a> { + pub fn exec(&'a self, node: Node<'a>) -> impl Iterator> + 'a { + unsafe { + ffi::ts_query_context_exec(self.0, node.0); + } + std::iter::from_fn(move || -> Option> { + unsafe { + if ffi::ts_query_context_next(self.0) { + Some(QueryMatch(self)) + } else { + None + } + } + }) + } +} + +impl<'a> QueryMatch<'a> { + pub fn pattern_index(&self) -> usize { + unsafe { ffi::ts_query_context_matched_pattern_index((self.0).0) as usize } + } + + pub fn captures(&self) -> impl ExactSizeIterator { + unsafe { + let mut capture_count = 0u32; + let captures = + ffi::ts_query_context_matched_captures((self.0).0, &mut capture_count as *mut u32); + let captures = slice::from_raw_parts(captures, capture_count as usize); + captures + .iter() + .map(move |capture| (capture.index as usize, Node::new(capture.node).unwrap())) + } + } +} + +impl PartialEq for Query { + fn eq(&self, other: &Self) -> bool { + self.ptr == other.ptr + } +} + +impl Drop for Query { + fn drop(&mut self) { + unsafe { ffi::ts_query_delete(self.ptr) } + } +} + +impl<'a> Drop for QueryContext<'a> { + fn drop(&mut self) { + unsafe { ffi::ts_query_context_delete(self.0) } + } +} + impl Point { pub fn new(row: usize, column: usize) -> Self { Point { row, column } diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index d39d0521..ad991818 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -26,6 +26,8 @@ typedef uint16_t TSFieldId; typedef struct TSLanguage TSLanguage; typedef struct TSParser TSParser; typedef struct TSTree TSTree; +typedef struct TSQuery TSQuery; +typedef struct TSQueryContext TSQueryContext; typedef enum { TSInputEncodingUTF8, @@ -87,6 +89,18 @@ typedef struct { uint32_t context[2]; } TSTreeCursor; +typedef struct { + TSNode node; + uint32_t index; +} TSQueryCapture; + +typedef enum { + TSQueryErrorNone = 0, + TSQueryErrorSyntax, + TSQueryErrorNodeType, + TSQueryErrorField, +} TSQueryError; + /********************/ /* Section - Parser */ /********************/ @@ -602,6 +616,107 @@ int64_t ts_tree_cursor_goto_first_child_for_byte(TSTreeCursor *, uint32_t); TSTreeCursor ts_tree_cursor_copy(const TSTreeCursor *); +/*******************/ +/* Section - Query */ +/*******************/ + +/** + * Create a new query from a string containing one or more S-expression + * patterns. The query is associated with a particular language, and can + * only be run on syntax nodes parsed with that language. + * + * If all of the given patterns are valid, this returns a `TSQuery`. + * If a pattern is invalid, this returns `NULL`, and provides two pieces + * of information about the problem: + * 1. The byte offset of the error is written to the `error_offset` parameter. + * 2. The type of error is written to the `error_type` parameter. + */ +TSQuery *ts_query_new( + const TSLanguage *language, + const char *source, + uint32_t source_len, + uint32_t *error_offset, + TSQueryError *error_type +); + +/** + * Delete a query, freeing all of the memory that it used. + */ +void ts_query_delete(TSQuery *); + +/* + * Get the number of distinct capture names in the query. + */ +uint32_t ts_query_capture_count(const TSQuery *); + +/* + * Get the name and length of one of the query's capture. Each capture + * is associated with a numeric id based on the order that it appeared + * in the query's source. + */ +const char *ts_query_capture_name_for_id( + const TSQuery *self, + uint32_t index, + uint32_t *length +); + +/* + * Get the numeric id of the capture with the given name. + */ +int ts_query_capture_id_for_name( + const TSQuery *self, + const char *name, + uint32_t length +); + +/* + * Create a new context for executing a given query. + * + * The context stores the state that is needed to iteratively search + * for matches. To use the query context: + * 1. First call `ts_query_context_exec` to start running the query + * on a particular syntax node. + * 2. Then repeatedly call `ts_query_context_next` to iterate over + * the matches. + * 3. For each match, you can call `ts_query_context_matched_pattern_index` + * to determine which pattern matched. You can also call + * `ts_query_context_matched_captures` to determine which nodes + * were captured by which capture names. + * + * If you don't care about finding all of the matches, you can stop calling + * `ts_query_context_next` at any point. And you can start executing the + * query against a different node by calling `ts_query_context_exec` again. + */ +TSQueryContext *ts_query_context_new(const TSQuery *); + +/* + * Delete a query context, freeing all of the memory that it used. + */ +void ts_query_context_delete(TSQueryContext *); + +/* + * Start running a query on a given node. + */ +void ts_query_context_exec(TSQueryContext *, TSNode); + +/* + * Advance to the next match of the currently running query. + */ +bool ts_query_context_next(TSQueryContext *); + +/* + * Check which pattern matched. + */ +uint32_t ts_query_context_matched_pattern_index(const TSQueryContext *); + +/* + * Check which pattern matched. + */ +const TSQueryCapture *ts_query_context_matched_captures( + const TSQueryContext *, + uint32_t * +); + /**********************/ /* Section - Language */ /**********************/ diff --git a/lib/src/bits.h b/lib/src/bits.h new file mode 100644 index 00000000..0caa1d8d --- /dev/null +++ b/lib/src/bits.h @@ -0,0 +1,25 @@ +#ifndef TREE_SITTER_BITS_H_ +#define TREE_SITTER_BITS_H_ + +#include + +#ifdef _WIN32 + +#include + +static inline uint32_t count_leading_zeros(uint32_t x) { + if (x == 0) return 32; + uint32_t result; + _BitScanReverse(&reuslt, x); + return result; +} + +#else + +static inline uint32_t count_leading_zeros(uint32_t x) { + if (x == 0) return 32; + return __builtin_clz(x); +} + +#endif +#endif // TREE_SITTER_BITS_H_ diff --git a/lib/src/lib.c b/lib/src/lib.c index fc5fbc92..900304f0 100644 --- a/lib/src/lib.c +++ b/lib/src/lib.c @@ -12,6 +12,7 @@ #include "./lexer.c" #include "./node.c" #include "./parser.c" +#include "./query.c" #include "./stack.c" #include "./subtree.c" #include "./tree_cursor.c" diff --git a/lib/src/query.c b/lib/src/query.c new file mode 100644 index 00000000..f7836a86 --- /dev/null +++ b/lib/src/query.c @@ -0,0 +1,810 @@ +#include "tree_sitter/api.h" +#include "./alloc.h" +#include "./array.h" +#include "./bits.h" +#include "utf8proc.h" +#include + +/* + * Stream - A sequence of unicode characters derived from a UTF8 string. + * This struct is used in parsing query S-expressions. + */ +typedef struct { + const char *input; + const char *end; + int32_t next; + uint8_t next_size; +} Stream; + +/* + * QueryStep - A step in the process of matching a query. Each node within + * a query S-expression maps to one of these steps. An entire pattern is + * represented as a sequence of these steps. + */ +typedef struct { + TSSymbol symbol; + TSFieldId field; + uint16_t capture_id; + uint8_t depth; + bool field_is_multiple; +} QueryStep; + +/* + * CaptureSlice - The name of a capture, represented as a slice of a + * shared string. + */ +typedef struct { + uint32_t offset; + uint32_t length; +} CaptureSlice; + +/* + * PatternSlice - The set of steps needed to match a particular pattern, + * represented as a slice of a shared array. + */ +typedef struct { + uint16_t step_index; + uint16_t pattern_index; +} PatternSlice; + +/* + * QueryState - The state of an in-progress match of a particular pattern + * in a query. While executing, a QueryContext must keep track of a number + * of possible in-progress matches. Each of those possible matches is + * represented as one of these states. + */ +typedef struct { + uint16_t step_index; + uint16_t pattern_index; + uint16_t start_depth; + uint16_t capture_list_id; + uint16_t capture_count; +} QueryState; + +/* + * CaptureListPool - A collection of *lists* of captures. Each QueryState + * needs to maintain its own list of captures. They are all represented as + * slices of one shared array. The CaptureListPool keeps track of which + * parts of the shared array are currently in use by a QueryState. + */ +typedef struct { + TSQueryCapture *contents; + uint32_t list_size; + uint32_t usage_map; +} CaptureListPool; + +/* + * TSQuery - A tree query, compiled from a string of S-expressions. The query + * itself is immutable. The mutable state used in the process of executing the + * query is stored in a `TSQueryContext`. + */ +struct TSQuery { + Array(QueryStep) steps; + Array(char) capture_data; + Array(CaptureSlice) capture_names; + Array(PatternSlice) pattern_map; + const TSLanguage *language; + uint16_t max_capture_count; + uint16_t wildcard_root_pattern_count; +}; + +/* + * TSQueryContext - A stateful struct used to execute a query on a tree. + */ +struct TSQueryContext { + const TSQuery *query; + TSTreeCursor cursor; + Array(QueryState) states; + Array(QueryState) finished_states; + CaptureListPool capture_list_pool; + bool ascending; + uint32_t depth; +}; + +static const TSQueryError PARENT_DONE = -1; +static const uint8_t PATTERN_DONE_MARKER = UINT8_MAX; +static const uint16_t NONE = UINT16_MAX; +static const TSSymbol WILDCARD_SYMBOL = 0; +static const uint16_t MAX_STATE_COUNT = 32; + +/********** + * Stream + **********/ + +static bool stream_advance(Stream *self) { + if (self->input >= self->end) return false; + self->input += self->next_size; + int size = utf8proc_iterate( + (const uint8_t *)self->input, + self->end - self->input, + &self->next + ); + if (size <= 0) return false; + self->next_size = size; + return true; +} + +static void stream_reset(Stream *self, const char *input) { + self->input = input; + self->next_size = 0; + stream_advance(self); +} + +static Stream stream_new(const char *string, uint32_t length) { + Stream self = { + .next = 0, + .input = string, + .end = string + length, + }; + stream_advance(&self); + return self; +} + +static void stream_skip_whitespace(Stream *stream) { + while (iswspace(stream->next)) stream_advance(stream); +} + +static bool stream_is_ident_start(Stream *stream) { + return iswalpha(stream->next) || stream->next == '_' || stream->next == '-'; +} + +static void stream_scan_identifier(Stream *stream) { + do { + stream_advance(stream); + } while ( + iswalnum(stream->next) || + stream->next == '_' || + stream->next == '-' || + stream->next == '.' + ); +} + +/****************** + * CaptureListPool + ******************/ + +static CaptureListPool capture_list_pool_new(uint16_t list_size) { + return (CaptureListPool) { + .contents = ts_calloc(MAX_STATE_COUNT * list_size, sizeof(TSQueryCapture)), + .list_size = list_size, + .usage_map = UINT32_MAX, + }; +} + +static void capture_list_pool_clear(CaptureListPool *self) { + self->usage_map = UINT32_MAX; +} + +static void capture_list_pool_delete(CaptureListPool *self) { + ts_free(self->contents); +} + +static TSQueryCapture *capture_list_pool_get(CaptureListPool *self, uint16_t id) { + return &self->contents[id * self->list_size]; +} + +static uint16_t capture_list_pool_acquire(CaptureListPool *self) { + uint16_t id = count_leading_zeros(self->usage_map); + if (id == 32) return NONE; + self->usage_map &= ~(1 << id); + return id; +} + +static void capture_list_pool_release(CaptureListPool *self, uint16_t id) { + self->usage_map |= (1 << id); +} + +/********* + * Query + *********/ + +static TSSymbol ts_query_intern_node_name( + const TSQuery *self, + const char *name, + uint32_t length, + TSSymbolType symbol_type +) { + uint32_t symbol_count = ts_language_symbol_count(self->language); + for (TSSymbol i = 0; i < symbol_count; i++) { + if ( + ts_language_symbol_type(self->language, i) == symbol_type && + !strncmp(ts_language_symbol_name(self->language, i), name, length) + ) return i; + } + return 0; +} + +static uint16_t ts_query_intern_capture_name( + TSQuery *self, + const char *name, + uint32_t length +) { + int id = ts_query_capture_id_for_name(self, name, length); + if (id >= 0) { + return (uint16_t)id; + } + + CaptureSlice capture = { + .offset = self->capture_data.size, + .length = length, + }; + array_grow_by(&self->capture_data, length + 1); + memcpy(&self->capture_data.contents[capture.offset], name, length); + self->capture_data.contents[self->capture_data.size - 1] = 0; + array_push(&self->capture_names, capture); + return self->capture_names.size - 1; +} + +static inline bool ts_query__pattern_map_search( + const TSQuery *self, + TSSymbol needle, + uint32_t *result +) { + uint32_t base_index = self->wildcard_root_pattern_count; + uint32_t size = self->pattern_map.size - base_index; + if (size == 0) { + *result = base_index; + return false; + } + while (size > 1) { + uint32_t half_size = size / 2; + uint32_t mid_index = base_index + half_size; + TSSymbol mid_symbol = self->steps.contents[ + self->pattern_map.contents[mid_index].step_index + ].symbol; + if (needle > mid_symbol) base_index = mid_index; + size -= half_size; + } + TSSymbol symbol = self->steps.contents[ + self->pattern_map.contents[base_index].step_index + ].symbol; + if (needle > symbol) { + *result = base_index; + return false; + } else if (needle == symbol) { + *result = base_index; + return true; + } else { + *result = base_index + 1; + return false; + } +} + +static inline void ts_query__pattern_map_insert( + TSQuery *self, + TSSymbol symbol, + uint32_t start_step_index +) { + uint32_t index; + ts_query__pattern_map_search(self, symbol, &index); + array_insert(&self->pattern_map, index, ((PatternSlice) { + .step_index = start_step_index, + .pattern_index = self->pattern_map.size, + })); +} + +static TSQueryError ts_query_parse_pattern( + TSQuery *self, + Stream *stream, + uint32_t depth, + uint32_t *capture_count +) { + uint16_t starting_step_index = self->steps.size; + + if (stream->next == 0) return TSQueryErrorSyntax; + + // Finish the parent S-expression + if (stream->next == ')') { + return PARENT_DONE; + } + + // Parse a parenthesized node expression + else if (stream->next == '(') { + stream_advance(stream); + stream_skip_whitespace(stream); + TSSymbol symbol; + + // Parse the wildcard symbol + if (stream->next == '*') { + symbol = WILDCARD_SYMBOL; + stream_advance(stream); + } + + // Parse a normal node name + else if (stream_is_ident_start(stream)) { + const char *node_name = stream->input; + stream_scan_identifier(stream); + uint32_t length = stream->input - node_name; + symbol = ts_query_intern_node_name( + self, + node_name, + length, + TSSymbolTypeRegular + ); + if (!symbol) { + stream_reset(stream, node_name); + return TSQueryErrorNodeType; + } + } else { + return TSQueryErrorSyntax; + } + + // Add a step for the node. + array_push(&self->steps, ((QueryStep) { + .depth = depth, + .symbol = symbol, + .field = 0, + .capture_id = NONE, + })); + + // Parse the child patterns + stream_skip_whitespace(stream); + for (;;) { + TSQueryError e = ts_query_parse_pattern(self, stream, depth + 1, capture_count); + if (e == PARENT_DONE) { + stream_advance(stream); + break; + } else if (e) { + return e; + } + } + } + + // Parse a double-quoted anonymous leaf node expression + else if (stream->next == '"') { + stream_advance(stream); + + // Parse the string content + const char *string_content = stream->input; + while (stream->next && stream->next != '"') stream_advance(stream); + uint32_t length = stream->input - string_content; + + // Add a step for the node + TSSymbol symbol = ts_query_intern_node_name( + self, + string_content, + length, + TSSymbolTypeAnonymous + ); + if (!symbol) { + stream_reset(stream, string_content); + return TSQueryErrorNodeType; + } + array_push(&self->steps, ((QueryStep) { + .depth = depth, + .symbol = symbol, + .field = 0, + })); + + if (stream->next != '"') return TSQueryErrorSyntax; + stream_advance(stream); + } + + // Parse a field-prefixed pattern + else if (stream_is_ident_start(stream)) { + // Parse the field name + const char *field_name = stream->input; + stream_scan_identifier(stream); + uint32_t length = stream->input - field_name; + stream_skip_whitespace(stream); + + if (stream->next != ':') { + stream_reset(stream, field_name); + return TSQueryErrorSyntax; + } + stream_advance(stream); + stream_skip_whitespace(stream); + + // Parse the pattern + uint32_t step_index = self->steps.size; + TSQueryError e = ts_query_parse_pattern(self, stream, depth, capture_count); + if (e == PARENT_DONE) return TSQueryErrorSyntax; + if (e) return e; + + // Add the field name to the first step of the pattern + TSFieldId field_id = ts_language_field_id_for_name( + self->language, + field_name, + length + ); + if (!field_id) { + stream->input = field_name; + return TSQueryErrorField; + } + self->steps.contents[step_index].field = field_id; + } + + // Parse a wildcard pattern + else if (stream->next == '*') { + stream_advance(stream); + stream_skip_whitespace(stream); + + // Add a step that matches any kind of node + array_push(&self->steps, ((QueryStep) { + .depth = depth, + .symbol = WILDCARD_SYMBOL, + .field = 0, + })); + } + + // No match + else { + return TSQueryErrorSyntax; + } + + stream_skip_whitespace(stream); + + // Parse a '@'-suffixed capture pattern + if (stream->next == '@') { + stream_advance(stream); + stream_skip_whitespace(stream); + + // Parse the capture name + if (!stream_is_ident_start(stream)) return TSQueryErrorSyntax; + const char *capture_name = stream->input; + stream_scan_identifier(stream); + uint32_t length = stream->input - capture_name; + + // Add the capture id to the first step of the pattern + uint16_t capture_id = ts_query_intern_capture_name( + self, + capture_name, + length + ); + self->steps.contents[starting_step_index].capture_id = capture_id; + (*capture_count)++; + + stream_skip_whitespace(stream); + } + + return 0; +} + +TSQuery *ts_query_new( + const TSLanguage *language, + const char *source, + uint32_t source_len, + uint32_t *error_offset, + TSQueryError *error_type +) { + TSQuery *self = ts_malloc(sizeof(TSQuery)); + *self = (TSQuery) { + .steps = array_new(), + .pattern_map = array_new(), + .wildcard_root_pattern_count = 0, + .max_capture_count = 0, + .language = language, + }; + + // Parse all of the S-expressions in the given string. + Stream stream = stream_new(source, source_len); + stream_skip_whitespace(&stream); + uint32_t start_step_index; + for (;;) { + start_step_index = self->steps.size; + uint32_t capture_count = 0; + *error_type = ts_query_parse_pattern(self, &stream, 0, &capture_count); + array_push(&self->steps, ((QueryStep) { .depth = PATTERN_DONE_MARKER })); + + // If any pattern could not be parsed, then report the error information + // and terminate. + if (*error_type) { + *error_offset = stream.input - source; + ts_query_delete(self); + return NULL; + } + + // Maintain a map that can look up patterns for a given root symbol. + ts_query__pattern_map_insert( + self, + self->steps.contents[start_step_index].symbol, + start_step_index + ); + if (self->steps.contents[start_step_index].symbol == WILDCARD_SYMBOL) { + self->wildcard_root_pattern_count++; + } + + if (capture_count > self->max_capture_count) { + self->max_capture_count = capture_count; + } + + if (stream.input == stream.end) break; + } + + return self; +} + +void ts_query_delete(TSQuery *self) { + if (self) { + array_delete(&self->steps); + array_delete(&self->pattern_map); + array_delete(&self->capture_data); + array_delete(&self->capture_names); + ts_free(self); + } +} + +uint32_t ts_query_capture_count(const TSQuery *self) { + return self->capture_names.size; +} + +const char *ts_query_capture_name_for_id( + const TSQuery *self, + uint32_t index, + uint32_t *length +) { + CaptureSlice name = self->capture_names.contents[index]; + *length = name.length; + return &self->capture_data.contents[name.offset]; +} + +int ts_query_capture_id_for_name( + const TSQuery *self, + const char *name, + uint32_t length +) { + for (unsigned i = 0; i < self->capture_names.size; i++) { + CaptureSlice existing = self->capture_names.contents[i]; + if ( + existing.length == length && + !strncmp(&self->capture_data.contents[existing.offset], name, length) + ) return i; + } + return -1; +} + +/*************** + * QueryContext + ***************/ + +TSQueryContext *ts_query_context_new(const TSQuery *query) { + TSQueryContext *self = ts_malloc(sizeof(TSQueryContext)); + *self = (TSQueryContext) { + .query = query, + .ascending = false, + .states = array_new(), + .finished_states = array_new(), + .capture_list_pool = capture_list_pool_new(query->max_capture_count), + }; + return self; +} + +void ts_query_context_delete(TSQueryContext *self) { + array_delete(&self->states); + array_delete(&self->finished_states); + ts_tree_cursor_delete(&self->cursor); + capture_list_pool_delete(&self->capture_list_pool); + ts_free(self); +} + +void ts_query_context_exec(TSQueryContext *self, TSNode node) { + array_clear(&self->states); + array_clear(&self->finished_states); + ts_tree_cursor_reset(&self->cursor, node); + capture_list_pool_clear(&self->capture_list_pool); + self->depth = 0; + self->ascending = false; +} + +bool ts_query_context_next(TSQueryContext *self) { + if (self->finished_states.size > 0) { + array_pop(&self->finished_states); + } + + while (self->finished_states.size == 0) { + if (self->ascending) { + // Remove any states that were started within this node and are still + // not complete. + uint32_t deleted_count = 0; + for (unsigned i = 0, n = self->states.size; i < n; i++) { + QueryState *state = &self->states.contents[i]; + if (state->start_depth == self->depth) { + + // printf("FAIL STATE pattern: %u, step: %u\n", state->pattern_index, state->step_index); + + capture_list_pool_release( + &self->capture_list_pool, + state->capture_list_id + ); + deleted_count++; + } else if (deleted_count > 0) { + self->states.contents[i - deleted_count] = *state; + } + } + + // if (deleted_count) { + // printf("FAILED %u of %u states\n", deleted_count, self->states.size); + // } + + self->states.size -= deleted_count; + + if (ts_tree_cursor_goto_next_sibling(&self->cursor)) { + self->ascending = false; + } else if (ts_tree_cursor_goto_parent(&self->cursor)) { + self->depth--; + } else { + return false; + } + } else { + TSFieldId field_id = NONE; + bool field_occurs_in_later_sibling = false; + TSNode node = ts_tree_cursor_current_node(&self->cursor); + TSSymbol symbol = ts_node_symbol(node); + + // printf("DESCEND INTO NODE: %s\n", ts_node_type(node)); + + // Add new states for any patterns whose root node is a wildcard. + for (unsigned i = 0; i < self->query->wildcard_root_pattern_count; i++) { + PatternSlice *slice = &self->query->pattern_map.contents[i]; + QueryStep *step = &self->query->steps.contents[slice->step_index]; + + // Check that the node matches the criteria for the first step + // of the pattern. + if (step->field) { + if (field_id == NONE) { + field_id = ts_tree_cursor_current_field_id_ext( + &self->cursor, + &field_occurs_in_later_sibling + ); + } + if (field_id != step->field) continue; + } + + // Add a new state at the start of this pattern. + uint32_t capture_list_id = capture_list_pool_acquire( + &self->capture_list_pool + ); + if (capture_list_id == NONE) break; + array_push(&self->states, ((QueryState) { + .step_index = slice->step_index, + .pattern_index = slice->pattern_index, + .capture_list_id = capture_list_id, + })); + } + + // Add new states for any patterns whose root node matches this node. + unsigned i; + if (ts_query__pattern_map_search(self->query, symbol, &i)) { + PatternSlice *slice = &self->query->pattern_map.contents[i]; + QueryStep *step = &self->query->steps.contents[slice->step_index]; + do { + if (step->field) { + if (field_id == NONE) { + field_id = ts_tree_cursor_current_field_id_ext( + &self->cursor, + &field_occurs_in_later_sibling + ); + } + if (field_id != step->field) continue; + } + + // printf("START NEW STATE: %u\n", slice->pattern_index); + + // If the node matches the first step of the pattern, then add + // a new in-progress state. First, acquire a list to hold the + // pattern's captures. + uint32_t capture_list_id = capture_list_pool_acquire( + &self->capture_list_pool + ); + if (capture_list_id == NONE) break; + + array_push(&self->states, ((QueryState) { + .pattern_index = slice->pattern_index, + .step_index = slice->step_index + 1, + .start_depth = self->depth, + .capture_list_id = capture_list_id, + .capture_count = 0, + })); + + i++; + if (i == self->query->pattern_map.size) break; + slice = &self->query->pattern_map.contents[i]; + step = &self->query->steps.contents[slice->step_index]; + } while (step->symbol == symbol); + } + + // Update all of the in-progress states with current node. + for (unsigned i = 0, n = self->states.size; i < n; i++) { + QueryState *state = &self->states.contents[i]; + QueryStep *step = &self->query->steps.contents[state->step_index]; + + // Check that the node matches all of the criteria for the next + // step of the pattern. + if (state->start_depth + step->depth != self->depth) continue; + if (step->symbol && step->symbol != symbol) continue; + if (step->field) { + // Only compute the current field if it is needed for the current + // step of some in-progress pattern. + if (field_id == NONE) { + field_id = ts_tree_cursor_current_field_id_ext( + &self->cursor, + &field_occurs_in_later_sibling + ); + } + if (field_id != step->field) continue; + } + + // Some patterns can match their root node in multiple ways, + // capturing different children. If this pattern step could match + // later children within the same parent, then this query state + // cannot simply be updated in place. It must be split into two + // states: one that captures this node, and one which skips over + // this node, to preserve the possibility of capturing later + // siblings. + QueryState *next_state = state; + if (step->depth > 0 && (!step->field || field_occurs_in_later_sibling)) { + uint32_t capture_list_id = capture_list_pool_acquire( + &self->capture_list_pool + ); + if (capture_list_id != NONE) { + array_push(&self->states, *state); + next_state = array_back(&self->states); + next_state->capture_list_id = capture_list_id; + } + } + + // Record captures + if (step->capture_id != NONE) { + // printf("CAPTURE id: %u\n", step->capture_id); + + TSQueryCapture *capture_list = capture_list_pool_get( + &self->capture_list_pool, + next_state->capture_list_id + ); + capture_list[next_state->capture_count++] = (TSQueryCapture) { + node, + step->capture_id + }; + } + + // If the pattern is now done, then populate the query context's + // finished state. + next_state->step_index++; + QueryStep *next_step = step + 1; + if (next_step->depth == PATTERN_DONE_MARKER) { + // printf("FINISHED MATCH pattern: %u\n", next_state->pattern_index); + + array_push(&self->finished_states, *next_state); + if (next_state == state) { + array_erase(&self->states, i); + i--; + n--; + } else { + array_pop(&self->states); + } + } + } + + if (ts_tree_cursor_goto_first_child(&self->cursor)) { + self->depth++; + } else { + self->ascending = true; + } + } + } + + return true; +} + +uint32_t ts_query_context_matched_pattern_index(const TSQueryContext *self) { + if (self->finished_states.size > 0) { + QueryState *state = array_back(&self->finished_states); + return state->pattern_index; + } + return 0; +} + +const TSQueryCapture *ts_query_context_matched_captures( + const TSQueryContext *self, + uint32_t *count +) { + if (self->finished_states.size > 0) { + QueryState *state = array_back(&self->finished_states); + *count = state->capture_count; + return capture_list_pool_get( + (CaptureListPool *)&self->capture_list_pool, + state->capture_list_id + ); + } + return NULL; +} diff --git a/lib/src/tree_cursor.c b/lib/src/tree_cursor.c index ba77ebc0..2ba3f947 100644 --- a/lib/src/tree_cursor.c +++ b/lib/src/tree_cursor.c @@ -244,7 +244,12 @@ TSNode ts_tree_cursor_current_node(const TSTreeCursor *_self) { ); } -TSFieldId ts_tree_cursor_current_field_id(const TSTreeCursor *_self) { +static inline TSFieldId ts_tree_cursor__current_field_info( + const TSTreeCursor *_self, + const TSFieldMapEntry **field_map, + const TSFieldMapEntry **field_map_end, + uint32_t *child_index +) { const TreeCursor *self = (const TreeCursor *)_self; // Walk up the tree, visiting the current node and its invisible ancestors. @@ -264,25 +269,61 @@ TSFieldId ts_tree_cursor_current_field_id(const TSTreeCursor *_self) { } } - const TSFieldMapEntry *field_map, *field_map_end; + if (ts_subtree_extra(*entry->subtree)) break; + ts_language_field_map( self->tree->language, parent_entry->subtree->ptr->production_id, - &field_map, &field_map_end + field_map, field_map_end ); - - while (field_map < field_map_end) { - if ( - !ts_subtree_extra(*entry->subtree) && - !field_map->inherited && - field_map->child_index == entry->structural_child_index - ) return field_map->field_id; - field_map++; + for (const TSFieldMapEntry *i = *field_map; i < *field_map_end; i++) { + if (!i->inherited && i->child_index == entry->structural_child_index) { + *child_index = entry->structural_child_index; + return i->field_id; + } } } return 0; } +TSFieldId ts_tree_cursor_current_field_id_ext( + const TSTreeCursor *self, + bool *field_has_additional +) { + uint32_t child_index; + const TSFieldMapEntry *field_map, *field_map_end; + TSFieldId field_id = ts_tree_cursor__current_field_info( + self, + &field_map, + &field_map_end, + &child_index + ); + + // After finding the field, check if any other later children have + // the same field name. + if (field_id) { + for (const TSFieldMapEntry *i = field_map; i < field_map_end; i++) { + if (i->field_id == field_id && i->child_index > child_index) { + *field_has_additional = true; + } + } + } + + return field_id; +} + + +TSFieldId ts_tree_cursor_current_field_id(const TSTreeCursor *self) { + uint32_t child_index; + const TSFieldMapEntry *field_map, *field_map_end; + return ts_tree_cursor__current_field_info( + self, + &field_map, + &field_map_end, + &child_index + ); +} + const char *ts_tree_cursor_current_field_name(const TSTreeCursor *_self) { TSFieldId id = ts_tree_cursor_current_field_id(_self); if (id) { diff --git a/lib/src/tree_cursor.h b/lib/src/tree_cursor.h index 55bdad86..9b438843 100644 --- a/lib/src/tree_cursor.h +++ b/lib/src/tree_cursor.h @@ -16,5 +16,6 @@ typedef struct { } TreeCursor; void ts_tree_cursor_init(TreeCursor *, TSNode); +TSFieldId ts_tree_cursor_current_field_id_ext(const TSTreeCursor *, bool *); #endif // TREE_SITTER_TREE_CURSOR_H_