Start work on an API for querying trees
This commit is contained in:
parent
4151a428ec
commit
fe7c74e7aa
10 changed files with 1430 additions and 12 deletions
|
|
@ -4,4 +4,5 @@ mod highlight_test;
|
|||
mod node_test;
|
||||
mod parser_test;
|
||||
mod properties_test;
|
||||
mod query_test;
|
||||
mod tree_test;
|
||||
|
|
|
|||
216
cli/src/tests/query_test.rs
Normal file
216
cli/src/tests/query_test.rs
Normal file
|
|
@ -0,0 +1,216 @@
|
|||
use super::helpers::allocations;
|
||||
use super::helpers::fixtures::get_language;
|
||||
use tree_sitter::{Parser, Query, QueryError, QueryMatch};
|
||||
|
||||
#[test]
fn test_query_errors_on_invalid_syntax() {
    allocations::start_recording();

    let language = get_language("javascript");

    // Well-formed patterns compile successfully.
    assert!(Query::new(language, "(if_statement)").is_ok());
    assert!(Query::new(language, "(if_statement condition:(identifier))").is_ok());

    // Mismatched parens
    assert_eq!(
        Query::new(language, "(if_statement"),
        Err(QueryError::Syntax(13))
    );
    assert_eq!(
        Query::new(language, "(if_statement))"),
        Err(QueryError::Syntax(14))
    );

    // Return an error at the *beginning* of a bare identifier not followed a colon.
    // If there's a colon but no pattern, return an error at the end of the colon.
    assert_eq!(
        Query::new(language, "(if_statement identifier)"),
        Err(QueryError::Syntax(14))
    );
    // NOTE(review): this assertion appeared twice in the original; the exact
    // duplicate has been removed.
    assert_eq!(
        Query::new(language, "(if_statement condition:)"),
        Err(QueryError::Syntax(24))
    );

    allocations::stop_recording();
}
|
||||
|
||||
#[test]
fn test_query_errors_on_invalid_symbols() {
    allocations::start_recording();

    let language = get_language("javascript");

    // An unknown node type is reported with the offending name, whether it
    // appears at the top level, nested, or behind a field label.
    assert_eq!(
        Query::new(language, "(non_existent1)"),
        Err(QueryError::NodeType("non_existent1"))
    );
    assert_eq!(
        Query::new(language, "(if_statement (non_existent2))"),
        Err(QueryError::NodeType("non_existent2"))
    );
    assert_eq!(
        Query::new(language, "(if_statement condition: (non_existent3))"),
        Err(QueryError::NodeType("non_existent3"))
    );
    // An unknown field name is reported as a distinct error variant.
    assert_eq!(
        Query::new(language, "(if_statement not_a_field: (identifier))"),
        Err(QueryError::Field("not_a_field"))
    );

    allocations::stop_recording();
}
|
||||
|
||||
#[test]
fn test_query_capture_names() {
    allocations::start_recording();

    let language = get_language("javascript");
    // Capture ids are assigned in order of first appearance in the query
    // source, across all patterns.
    let query = Query::new(
        language,
        r#"
        (if_statement
          condition: (binary_expression
            left: * @left-operand
            operator: "||"
            right: * @right-operand)
          consequence: (statement_block) @body)

        (while_statement
          condition:* @loop-condition)
        "#,
    )
    .unwrap();

    assert_eq!(
        query.capture_names(),
        &[
            "left-operand".to_string(),
            "right-operand".to_string(),
            "body".to_string(),
            "loop-condition".to_string(),
        ]
    );

    // Drop explicitly so the allocation recorder sees all memory released
    // before it stops.
    drop(query);
    allocations::stop_recording();
}
|
||||
|
||||
#[test]
fn test_query_exec_with_simple_pattern() {
    allocations::start_recording();

    let language = get_language("javascript");
    let query = Query::new(
        language,
        "(function_declaration name: (identifier) @fn-name)",
    )
    .unwrap();

    let source = "function one() { two(); function three() {} }";
    let mut parser = Parser::new();
    parser.set_language(language).unwrap();
    let tree = parser.parse(source, None).unwrap();

    let context = query.context();
    let matches = context.exec(tree.root_node());

    // Both function declarations match pattern 0; the call to `two` does not.
    assert_eq!(
        collect_matches(matches, &query, source),
        &[
            (0, vec![("fn-name", "one")]),
            (0, vec![("fn-name", "three")])
        ],
    );

    // Drop everything explicitly so the allocation recorder sees all memory
    // released before it stops.
    drop(context);
    drop(parser);
    drop(query);
    drop(tree);
    allocations::stop_recording();
}
|
||||
|
||||
#[test]
fn test_query_exec_with_multiple_matches_same_root() {
    allocations::start_recording();

    let language = get_language("javascript");
    // One pattern capturing both the class name and a method name: a class
    // with several methods should yield one match per method, all rooted at
    // the same class declaration.
    let query = Query::new(
        language,
        "(class_declaration
            name: (identifier) @the-class-name
            (class_body
              (method_definition
                name: (property_identifier) @the-method-name)))",
    )
    .unwrap();

    let source = "
      class Person {
        // the constructor
        constructor(name) { this.name = name; }

        // the getter
        getFullName() { return this.name; }
      }
    ";

    let mut parser = Parser::new();
    parser.set_language(language).unwrap();
    let tree = parser.parse(source, None).unwrap();
    let context = query.context();
    let matches = context.exec(tree.root_node());

    // Each method produces its own match for pattern 0, repeating the
    // enclosing class name.
    assert_eq!(
        collect_matches(matches, &query, source),
        &[
            (
                0,
                vec![
                    ("the-class-name", "Person"),
                    ("the-method-name", "constructor")
                ]
            ),
            (
                0,
                vec![
                    ("the-class-name", "Person"),
                    ("the-method-name", "getFullName")
                ]
            ),
        ],
    );

    // Drop everything explicitly so the allocation recorder sees all memory
    // released before it stops.
    drop(context);
    drop(parser);
    drop(query);
    drop(tree);
    allocations::stop_recording();
}
|
||||
|
||||
fn collect_matches<'a>(
|
||||
matches: impl Iterator<Item = QueryMatch<'a>>,
|
||||
query: &'a Query,
|
||||
source: &'a str,
|
||||
) -> Vec<(usize, Vec<(&'a str, &'a str)>)> {
|
||||
matches
|
||||
.map(|m| {
|
||||
(
|
||||
m.pattern_index(),
|
||||
m.captures()
|
||||
.map(|(capture_id, node)| {
|
||||
(
|
||||
query.capture_names()[capture_id].as_str(),
|
||||
node.utf8_text(source.as_bytes()).unwrap(),
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
|
@ -19,6 +19,16 @@ pub struct TSParser {
|
|||
pub struct TSTree {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct TSQuery {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct TSQueryContext {
|
||||
_unused: [u8; 0],
|
||||
}
|
||||
pub const TSInputEncoding_TSInputEncodingUTF8: TSInputEncoding = 0;
|
||||
pub const TSInputEncoding_TSInputEncodingUTF16: TSInputEncoding = 1;
|
||||
pub type TSInputEncoding = u32;
|
||||
|
|
@ -93,6 +103,17 @@ pub struct TSTreeCursor {
|
|||
pub id: *const ::std::os::raw::c_void,
|
||||
pub context: [u32; 2usize],
|
||||
}
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct TSQueryCapture {
|
||||
pub node: TSNode,
|
||||
pub index: u32,
|
||||
}
|
||||
pub const TSQueryError_TSQueryErrorNone: TSQueryError = 0;
|
||||
pub const TSQueryError_TSQueryErrorSyntax: TSQueryError = 1;
|
||||
pub const TSQueryError_TSQueryErrorNodeType: TSQueryError = 2;
|
||||
pub const TSQueryError_TSQueryErrorField: TSQueryError = 3;
|
||||
pub type TSQueryError = u32;
|
||||
extern "C" {
|
||||
#[doc = " Create a new parser."]
|
||||
pub fn ts_parser_new() -> *mut TSParser;
|
||||
|
|
@ -538,6 +559,65 @@ extern "C" {
|
|||
extern "C" {
|
||||
pub fn ts_tree_cursor_copy(arg1: *const TSTreeCursor) -> TSTreeCursor;
|
||||
}
|
||||
extern "C" {
|
||||
#[doc = " Create a new query based on a given language and string containing"]
|
||||
#[doc = " one or more S-expression patterns."]
|
||||
#[doc = ""]
|
||||
#[doc = " If all of the given patterns are valid, this returns a `TSQuery`."]
|
||||
#[doc = " If a pattern is invalid, this returns `NULL`, and provides two pieces"]
|
||||
#[doc = " of information about the problem:"]
|
||||
#[doc = " 1. The byte offset of the error is written to the `error_offset` parameter."]
|
||||
#[doc = " 2. The type of error is written to the `error_type` parameter."]
|
||||
pub fn ts_query_new(
|
||||
arg1: *const TSLanguage,
|
||||
source: *const ::std::os::raw::c_char,
|
||||
source_len: u32,
|
||||
error_offset: *mut u32,
|
||||
error_type: *mut TSQueryError,
|
||||
) -> *mut TSQuery;
|
||||
}
|
||||
extern "C" {
|
||||
#[doc = " Delete a query, freeing all of the memory that it used."]
|
||||
pub fn ts_query_delete(arg1: *mut TSQuery);
|
||||
}
|
||||
extern "C" {
|
||||
pub fn ts_query_capture_count(arg1: *const TSQuery) -> u32;
|
||||
}
|
||||
extern "C" {
|
||||
pub fn ts_query_capture_name_for_id(
|
||||
self_: *const TSQuery,
|
||||
index: u32,
|
||||
length: *mut u32,
|
||||
) -> *const ::std::os::raw::c_char;
|
||||
}
|
||||
extern "C" {
|
||||
pub fn ts_query_capture_id_for_name(
|
||||
self_: *const TSQuery,
|
||||
name: *const ::std::os::raw::c_char,
|
||||
length: u32,
|
||||
) -> ::std::os::raw::c_int;
|
||||
}
|
||||
extern "C" {
|
||||
pub fn ts_query_context_new(arg1: *const TSQuery) -> *mut TSQueryContext;
|
||||
}
|
||||
extern "C" {
|
||||
pub fn ts_query_context_delete(arg1: *mut TSQueryContext);
|
||||
}
|
||||
extern "C" {
|
||||
pub fn ts_query_context_exec(arg1: *mut TSQueryContext, arg2: TSNode);
|
||||
}
|
||||
extern "C" {
|
||||
pub fn ts_query_context_next(arg1: *mut TSQueryContext) -> bool;
|
||||
}
|
||||
extern "C" {
|
||||
pub fn ts_query_context_matched_pattern_index(arg1: *const TSQueryContext) -> u32;
|
||||
}
|
||||
extern "C" {
|
||||
pub fn ts_query_context_matched_captures(
|
||||
arg1: *const TSQueryContext,
|
||||
arg2: *mut u32,
|
||||
) -> *const TSQueryCapture;
|
||||
}
|
||||
extern "C" {
|
||||
#[doc = " Get the number of distinct node types in the language."]
|
||||
pub fn ts_language_symbol_count(arg1: *const TSLanguage) -> u32;
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ use std::ffi::CStr;
|
|||
use std::marker::PhantomData;
|
||||
use std::os::raw::{c_char, c_void};
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::{fmt, ptr, str, u16};
|
||||
use std::{char, fmt, ptr, slice, str, u16};
|
||||
|
||||
pub const LANGUAGE_VERSION: usize = ffi::TREE_SITTER_LANGUAGE_VERSION;
|
||||
pub const PARSER_HEADER: &'static str = include_str!("../include/tree_sitter/parser.h");
|
||||
|
|
@ -136,6 +136,23 @@ pub struct TreePropertyCursor<'a, P> {
|
|||
source: &'a [u8],
|
||||
}
|
||||
|
||||
/// A compiled tree query: one or more S-expression patterns that can be
/// matched against syntax nodes of a particular language.
#[derive(Debug)]
pub struct Query {
    ptr: *mut ffi::TSQuery,
    // Capture names indexed by capture id, copied out of the C query.
    capture_names: Vec<String>,
}

/// Mutable state used while executing a `Query`. The `PhantomData` lifetime
/// ties the context to the query it was created from.
pub struct QueryContext<'a>(*mut ffi::TSQueryContext, PhantomData<&'a ()>);

/// A single match produced by `QueryContext::exec`; borrows the context.
pub struct QueryMatch<'a>(&'a QueryContext<'a>);

/// An error produced while compiling a query's source string.
#[derive(Debug, PartialEq, Eq)]
pub enum QueryError<'a> {
    // Byte offset of a syntax error in the query source.
    Syntax(usize),
    // Unknown node type name, borrowed from the query source.
    NodeType(&'a str),
    // Unknown field name, borrowed from the query source.
    Field(&'a str),
}
|
||||
|
||||
impl Language {
|
||||
pub fn version(&self) -> usize {
|
||||
unsafe { ffi::ts_language_version(self.0) as usize }
|
||||
|
|
@ -921,6 +938,117 @@ impl<'a, P> TreePropertyCursor<'a, P> {
|
|||
}
|
||||
}
|
||||
|
||||
impl Query {
|
||||
pub fn new(language: Language, source: &str) -> Result<Self, QueryError> {
|
||||
let mut error_offset = 0u32;
|
||||
let mut error_type: ffi::TSQueryError = 0;
|
||||
let bytes = source.as_bytes();
|
||||
let ptr = unsafe {
|
||||
ffi::ts_query_new(
|
||||
language.0,
|
||||
bytes.as_ptr() as *const c_char,
|
||||
bytes.len() as u32,
|
||||
&mut error_offset as *mut u32,
|
||||
&mut error_type as *mut ffi::TSQueryError,
|
||||
)
|
||||
};
|
||||
if ptr.is_null() {
|
||||
let offset = error_offset as usize;
|
||||
Err(match error_type {
|
||||
ffi::TSQueryError_TSQueryErrorNodeType | ffi::TSQueryError_TSQueryErrorField => {
|
||||
let suffix = source.split_at(offset).1;
|
||||
let end_offset = suffix
|
||||
.find(|c| !char::is_alphanumeric(c) && c != '_' && c != '-')
|
||||
.unwrap_or(source.len());
|
||||
let name = suffix.split_at(end_offset).0;
|
||||
if error_type == ffi::TSQueryError_TSQueryErrorNodeType {
|
||||
QueryError::NodeType(name)
|
||||
} else {
|
||||
QueryError::Field(name)
|
||||
}
|
||||
}
|
||||
_ => QueryError::Syntax(offset),
|
||||
})
|
||||
} else {
|
||||
let capture_count = unsafe { ffi::ts_query_capture_count(ptr) };
|
||||
let capture_names = (0..capture_count)
|
||||
.map(|i| unsafe {
|
||||
let mut length = 0u32;
|
||||
let name =
|
||||
ffi::ts_query_capture_name_for_id(ptr, i as u32, &mut length as *mut u32)
|
||||
as *const u8;
|
||||
let name = slice::from_raw_parts(name, length as usize);
|
||||
let name = str::from_utf8_unchecked(name);
|
||||
name.to_string()
|
||||
})
|
||||
.collect();
|
||||
Ok(Query { ptr, capture_names })
|
||||
}
|
||||
}
|
||||
|
||||
pub fn capture_names(&self) -> &[String] {
|
||||
&self.capture_names
|
||||
}
|
||||
|
||||
pub fn context(&self) -> QueryContext {
|
||||
let context = unsafe { ffi::ts_query_context_new(self.ptr) };
|
||||
QueryContext(context, PhantomData)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> QueryContext<'a> {
    /// Run this context's query on the given node, returning a lazy iterator
    /// over the matches.
    ///
    /// NOTE(review): each yielded `QueryMatch` only borrows the context;
    /// advancing the iterator calls `ts_query_context_next`, which moves the
    /// underlying C context on to the next match — inspect a match before
    /// pulling the next one.
    pub fn exec(&'a self, node: Node<'a>) -> impl Iterator<Item = QueryMatch<'a>> + 'a {
        unsafe {
            ffi::ts_query_context_exec(self.0, node.0);
        }
        // Pull matches out of the C context one at a time until exhausted.
        std::iter::from_fn(move || -> Option<QueryMatch<'a>> {
            unsafe {
                if ffi::ts_query_context_next(self.0) {
                    Some(QueryMatch(self))
                } else {
                    None
                }
            }
        })
    }
}
|
||||
|
||||
impl<'a> QueryMatch<'a> {
    /// Index of the pattern (within the query) that produced this match.
    pub fn pattern_index(&self) -> usize {
        unsafe { ffi::ts_query_context_matched_pattern_index((self.0).0) as usize }
    }

    /// Iterate over the `(capture id, node)` pairs captured by this match.
    /// Capture ids index into `Query::capture_names`.
    pub fn captures(&self) -> impl ExactSizeIterator<Item = (usize, Node)> {
        unsafe {
            let mut capture_count = 0u32;
            let captures =
                ffi::ts_query_context_matched_captures((self.0).0, &mut capture_count as *mut u32);
            // SAFETY: the C API writes the capture count and returns a
            // pointer to that many TSQueryCapture values owned by the context.
            let captures = slice::from_raw_parts(captures, capture_count as usize);
            captures
                .iter()
                .map(move |capture| (capture.index as usize, Node::new(capture.node).unwrap()))
        }
    }
}
|
||||
|
||||
impl PartialEq for Query {
    // Two queries are equal only if they wrap the same underlying C object.
    fn eq(&self, other: &Self) -> bool {
        self.ptr == other.ptr
    }
}
|
||||
|
||||
impl Drop for Query {
    fn drop(&mut self) {
        // Free the C-side query allocated by `ts_query_new`.
        unsafe { ffi::ts_query_delete(self.ptr) }
    }
}
|
||||
|
||||
impl<'a> Drop for QueryContext<'a> {
    fn drop(&mut self) {
        // Free the C-side context allocated by `ts_query_context_new`.
        unsafe { ffi::ts_query_context_delete(self.0) }
    }
}
|
||||
|
||||
impl Point {
|
||||
pub fn new(row: usize, column: usize) -> Self {
|
||||
Point { row, column }
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ typedef uint16_t TSFieldId;
|
|||
typedef struct TSLanguage TSLanguage;
|
||||
typedef struct TSParser TSParser;
|
||||
typedef struct TSTree TSTree;
|
||||
typedef struct TSQuery TSQuery;
|
||||
typedef struct TSQueryContext TSQueryContext;
|
||||
|
||||
typedef enum {
|
||||
TSInputEncodingUTF8,
|
||||
|
|
@ -87,6 +89,18 @@ typedef struct {
|
|||
uint32_t context[2];
|
||||
} TSTreeCursor;
|
||||
|
||||
typedef struct {
|
||||
TSNode node;
|
||||
uint32_t index;
|
||||
} TSQueryCapture;
|
||||
|
||||
typedef enum {
|
||||
TSQueryErrorNone = 0,
|
||||
TSQueryErrorSyntax,
|
||||
TSQueryErrorNodeType,
|
||||
TSQueryErrorField,
|
||||
} TSQueryError;
|
||||
|
||||
/********************/
|
||||
/* Section - Parser */
|
||||
/********************/
|
||||
|
|
@ -602,6 +616,107 @@ int64_t ts_tree_cursor_goto_first_child_for_byte(TSTreeCursor *, uint32_t);
|
|||
|
||||
TSTreeCursor ts_tree_cursor_copy(const TSTreeCursor *);
|
||||
|
||||
/*******************/
|
||||
/* Section - Query */
|
||||
/*******************/
|
||||
|
||||
/**
|
||||
* Create a new query from a string containing one or more S-expression
|
||||
* patterns. The query is associated with a particular language, and can
|
||||
* only be run on syntax nodes parsed with that language.
|
||||
*
|
||||
* If all of the given patterns are valid, this returns a `TSQuery`.
|
||||
* If a pattern is invalid, this returns `NULL`, and provides two pieces
|
||||
* of information about the problem:
|
||||
* 1. The byte offset of the error is written to the `error_offset` parameter.
|
||||
* 2. The type of error is written to the `error_type` parameter.
|
||||
*/
|
||||
TSQuery *ts_query_new(
|
||||
const TSLanguage *language,
|
||||
const char *source,
|
||||
uint32_t source_len,
|
||||
uint32_t *error_offset,
|
||||
TSQueryError *error_type
|
||||
);
|
||||
|
||||
/**
|
||||
* Delete a query, freeing all of the memory that it used.
|
||||
*/
|
||||
void ts_query_delete(TSQuery *);
|
||||
|
||||
/*
|
||||
* Get the number of distinct capture names in the query.
|
||||
*/
|
||||
uint32_t ts_query_capture_count(const TSQuery *);
|
||||
|
||||
/*
|
||||
* Get the name and length of one of the query's capture. Each capture
|
||||
* is associated with a numeric id based on the order that it appeared
|
||||
* in the query's source.
|
||||
*/
|
||||
const char *ts_query_capture_name_for_id(
|
||||
const TSQuery *self,
|
||||
uint32_t index,
|
||||
uint32_t *length
|
||||
);
|
||||
|
||||
/*
|
||||
* Get the numeric id of the capture with the given name.
|
||||
*/
|
||||
int ts_query_capture_id_for_name(
|
||||
const TSQuery *self,
|
||||
const char *name,
|
||||
uint32_t length
|
||||
);
|
||||
|
||||
/*
|
||||
* Create a new context for executing a given query.
|
||||
*
|
||||
* The context stores the state that is needed to iteratively search
|
||||
* for matches. To use the query context:
|
||||
* 1. First call `ts_query_context_exec` to start running the query
|
||||
* on a particular syntax node.
|
||||
* 2. Then repeatedly call `ts_query_context_next` to iterate over
|
||||
* the matches.
|
||||
* 3. For each match, you can call `ts_query_context_matched_pattern_index`
|
||||
* to determine which pattern matched. You can also call
|
||||
* `ts_query_context_matched_captures` to determine which nodes
|
||||
* were captured by which capture names.
|
||||
*
|
||||
* If you don't care about finding all of the matches, you can stop calling
|
||||
* `ts_query_context_next` at any point. And you can start executing the
|
||||
* query against a different node by calling `ts_query_context_exec` again.
|
||||
*/
|
||||
TSQueryContext *ts_query_context_new(const TSQuery *);
|
||||
|
||||
/*
|
||||
* Delete a query context, freeing all of the memory that it used.
|
||||
*/
|
||||
void ts_query_context_delete(TSQueryContext *);
|
||||
|
||||
/*
|
||||
* Start running a query on a given node.
|
||||
*/
|
||||
void ts_query_context_exec(TSQueryContext *, TSNode);
|
||||
|
||||
/*
|
||||
* Advance to the next match of the currently running query.
|
||||
*/
|
||||
bool ts_query_context_next(TSQueryContext *);
|
||||
|
||||
/*
|
||||
* Check which pattern matched.
|
||||
*/
|
||||
uint32_t ts_query_context_matched_pattern_index(const TSQueryContext *);
|
||||
|
||||
/*
 * Get the captures associated with the current match.
 */
|
||||
const TSQueryCapture *ts_query_context_matched_captures(
|
||||
const TSQueryContext *,
|
||||
uint32_t *
|
||||
);
|
||||
|
||||
/**********************/
|
||||
/* Section - Language */
|
||||
/**********************/
|
||||
|
|
|
|||
25
lib/src/bits.h
Normal file
25
lib/src/bits.h
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
#ifndef TREE_SITTER_BITS_H_
|
||||
#define TREE_SITTER_BITS_H_
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#ifdef _WIN32

#include <intrin.h>

// Number of leading zero bits in a 32-bit integer. Returns 32 for zero,
// matching the behavior of the non-Windows branch below.
static inline uint32_t count_leading_zeros(uint32_t x) {
  if (x == 0) return 32;
  // BUG FIX: the out-parameter was misspelled `reuslt` (compile error) and
  // must be `unsigned long`, which is what _BitScanReverse expects.
  unsigned long result;
  _BitScanReverse(&result, x);
  // BUG FIX: _BitScanReverse yields the *index* of the highest set bit
  // (0-31), not the leading-zero count; convert so this agrees with
  // __builtin_clz.
  return 31 - result;
}

#else

// Number of leading zero bits in a 32-bit integer. __builtin_clz is
// undefined for zero, so that case is handled explicitly.
static inline uint32_t count_leading_zeros(uint32_t x) {
  if (x == 0) return 32;
  return __builtin_clz(x);
}

#endif
|
||||
#endif // TREE_SITTER_BITS_H_
|
||||
|
|
@ -12,6 +12,7 @@
|
|||
#include "./lexer.c"
|
||||
#include "./node.c"
|
||||
#include "./parser.c"
|
||||
#include "./query.c"
|
||||
#include "./stack.c"
|
||||
#include "./subtree.c"
|
||||
#include "./tree_cursor.c"
|
||||
|
|
|
|||
810
lib/src/query.c
Normal file
810
lib/src/query.c
Normal file
|
|
@ -0,0 +1,810 @@
|
|||
#include "tree_sitter/api.h"
|
||||
#include "./alloc.h"
|
||||
#include "./array.h"
|
||||
#include "./bits.h"
|
||||
#include "utf8proc.h"
|
||||
#include <wctype.h>
|
||||
|
||||
/*
|
||||
* Stream - A sequence of unicode characters derived from a UTF8 string.
|
||||
* This struct is used in parsing query S-expressions.
|
||||
*/
|
||||
typedef struct {
|
||||
const char *input;
|
||||
const char *end;
|
||||
int32_t next;
|
||||
uint8_t next_size;
|
||||
} Stream;
|
||||
|
||||
/*
|
||||
* QueryStep - A step in the process of matching a query. Each node within
|
||||
* a query S-expression maps to one of these steps. An entire pattern is
|
||||
* represented as a sequence of these steps.
|
||||
*/
|
||||
typedef struct {
|
||||
TSSymbol symbol;
|
||||
TSFieldId field;
|
||||
uint16_t capture_id;
|
||||
uint8_t depth;
|
||||
bool field_is_multiple;
|
||||
} QueryStep;
|
||||
|
||||
/*
|
||||
* CaptureSlice - The name of a capture, represented as a slice of a
|
||||
* shared string.
|
||||
*/
|
||||
typedef struct {
|
||||
uint32_t offset;
|
||||
uint32_t length;
|
||||
} CaptureSlice;
|
||||
|
||||
/*
|
||||
* PatternSlice - The set of steps needed to match a particular pattern,
|
||||
* represented as a slice of a shared array.
|
||||
*/
|
||||
typedef struct {
|
||||
uint16_t step_index;
|
||||
uint16_t pattern_index;
|
||||
} PatternSlice;
|
||||
|
||||
/*
|
||||
* QueryState - The state of an in-progress match of a particular pattern
|
||||
* in a query. While executing, a QueryContext must keep track of a number
|
||||
* of possible in-progress matches. Each of those possible matches is
|
||||
* represented as one of these states.
|
||||
*/
|
||||
typedef struct {
|
||||
uint16_t step_index;
|
||||
uint16_t pattern_index;
|
||||
uint16_t start_depth;
|
||||
uint16_t capture_list_id;
|
||||
uint16_t capture_count;
|
||||
} QueryState;
|
||||
|
||||
/*
|
||||
* CaptureListPool - A collection of *lists* of captures. Each QueryState
|
||||
* needs to maintain its own list of captures. They are all represented as
|
||||
* slices of one shared array. The CaptureListPool keeps track of which
|
||||
* parts of the shared array are currently in use by a QueryState.
|
||||
*/
|
||||
typedef struct {
|
||||
TSQueryCapture *contents;
|
||||
uint32_t list_size;
|
||||
uint32_t usage_map;
|
||||
} CaptureListPool;
|
||||
|
||||
/*
|
||||
* TSQuery - A tree query, compiled from a string of S-expressions. The query
|
||||
* itself is immutable. The mutable state used in the process of executing the
|
||||
* query is stored in a `TSQueryContext`.
|
||||
*/
|
||||
struct TSQuery {
|
||||
Array(QueryStep) steps;
|
||||
Array(char) capture_data;
|
||||
Array(CaptureSlice) capture_names;
|
||||
Array(PatternSlice) pattern_map;
|
||||
const TSLanguage *language;
|
||||
uint16_t max_capture_count;
|
||||
uint16_t wildcard_root_pattern_count;
|
||||
};
|
||||
|
||||
/*
|
||||
* TSQueryContext - A stateful struct used to execute a query on a tree.
|
||||
*/
|
||||
struct TSQueryContext {
|
||||
const TSQuery *query;
|
||||
TSTreeCursor cursor;
|
||||
Array(QueryState) states;
|
||||
Array(QueryState) finished_states;
|
||||
CaptureListPool capture_list_pool;
|
||||
bool ascending;
|
||||
uint32_t depth;
|
||||
};
|
||||
|
||||
static const TSQueryError PARENT_DONE = -1;
|
||||
static const uint8_t PATTERN_DONE_MARKER = UINT8_MAX;
|
||||
static const uint16_t NONE = UINT16_MAX;
|
||||
static const TSSymbol WILDCARD_SYMBOL = 0;
|
||||
static const uint16_t MAX_STATE_COUNT = 32;
|
||||
|
||||
/**********
|
||||
* Stream
|
||||
**********/
|
||||
|
||||
// Advance the stream past the current character and decode the next one
// from the UTF8 input into `self->next`. Returns false at end of input or
// when utf8proc reports invalid UTF8.
static bool stream_advance(Stream *self) {
  if (self->input >= self->end) return false;
  // Skip the bytes of the previously decoded character.
  self->input += self->next_size;
  int size = utf8proc_iterate(
    (const uint8_t *)self->input,
    self->end - self->input,
    &self->next
  );
  if (size <= 0) return false;
  self->next_size = size;
  return true;
}
|
||||
|
||||
// Reposition the stream at `input` (a pointer into the same buffer, e.g. to
// rewind to the start of an identifier) and decode the character there.
static void stream_reset(Stream *self, const char *input) {
  self->input = input;
  // Zero next_size so stream_advance decodes in place without skipping.
  self->next_size = 0;
  stream_advance(self);
}
|
||||
|
||||
// Create a stream over `length` bytes of `string`, pre-decoding the first
// character into `next`.
static Stream stream_new(const char *string, uint32_t length) {
  // Designated initializers zero the remaining field (next_size), so the
  // initial stream_advance decodes in place without skipping bytes.
  Stream self = {
    .next = 0,
    .input = string,
    .end = string + length,
  };
  stream_advance(&self);
  return self;
}
|
||||
|
||||
// Consume any run of whitespace characters.
static void stream_skip_whitespace(Stream *stream) {
  while (iswspace(stream->next)) stream_advance(stream);
}
|
||||
|
||||
// True if the current character can begin an identifier (a node name,
// field name, or capture name).
static bool stream_is_ident_start(Stream *stream) {
  return iswalpha(stream->next) || stream->next == '_' || stream->next == '-';
}
|
||||
|
||||
// Consume an identifier: the current character plus any following
// alphanumerics, '_', '-', or '.' characters.
static void stream_scan_identifier(Stream *stream) {
  do {
    stream_advance(stream);
  } while (
    iswalnum(stream->next) ||
    stream->next == '_' ||
    stream->next == '-' ||
    stream->next == '.'
  );
}
|
||||
|
||||
/******************
|
||||
* CaptureListPool
|
||||
******************/
|
||||
|
||||
// Allocate a pool of MAX_STATE_COUNT capture lists, each `list_size` entries
// long, stored contiguously. A set bit in `usage_map` marks a free list;
// initially every list is free.
static CaptureListPool capture_list_pool_new(uint16_t list_size) {
  return (CaptureListPool) {
    .contents = ts_calloc(MAX_STATE_COUNT * list_size, sizeof(TSQueryCapture)),
    .list_size = list_size,
    .usage_map = UINT32_MAX,
  };
}
|
||||
|
||||
// Mark every list in the pool as free without releasing the storage.
static void capture_list_pool_clear(CaptureListPool *self) {
  self->usage_map = UINT32_MAX;
}
|
||||
|
||||
// Free the pool's backing storage.
static void capture_list_pool_delete(CaptureListPool *self) {
  ts_free(self->contents);
}
|
||||
|
||||
// Get a pointer to the start of the capture list with the given id.
static TSQueryCapture *capture_list_pool_get(CaptureListPool *self, uint16_t id) {
  return &self->contents[id * self->list_size];
}
|
||||
|
||||
// Claim a free capture list, returning its id, or NONE if all
// MAX_STATE_COUNT lists are in use. Free lists are marked by set bits in
// `usage_map`.
static uint16_t capture_list_pool_acquire(CaptureListPool *self) {
  // BUG FIX: count_leading_zeros numbers bits from the most-significant
  // end, but the masks here (and in capture_list_pool_release) use
  // least-significant bit indices. The original cleared `1 << clz`, so
  // after acquiring id 0 (clearing bit 0) the leading-zero count of the
  // map was still 0 and the same id was handed out again. Translate the
  // leading-zero count into an LSB bit index before masking.
  uint32_t zeros = count_leading_zeros(self->usage_map);
  if (zeros == 32) return NONE;
  uint16_t id = (uint16_t)(31 - zeros);
  self->usage_map &= ~(1u << id);
  return id;
}
|
||||
|
||||
// Return the list with the given id to the pool by setting its free bit.
static void capture_list_pool_release(CaptureListPool *self, uint16_t id) {
  self->usage_map |= (1 << id);
}
|
||||
|
||||
/*********
|
||||
* Query
|
||||
*********/
|
||||
|
||||
static TSSymbol ts_query_intern_node_name(
|
||||
const TSQuery *self,
|
||||
const char *name,
|
||||
uint32_t length,
|
||||
TSSymbolType symbol_type
|
||||
) {
|
||||
uint32_t symbol_count = ts_language_symbol_count(self->language);
|
||||
for (TSSymbol i = 0; i < symbol_count; i++) {
|
||||
if (
|
||||
ts_language_symbol_type(self->language, i) == symbol_type &&
|
||||
!strncmp(ts_language_symbol_name(self->language, i), name, length)
|
||||
) return i;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Return the id of the capture with the given name, adding the name to the
// query's shared capture-name buffer if it has not been seen before.
static uint16_t ts_query_intern_capture_name(
  TSQuery *self,
  const char *name,
  uint32_t length
) {
  // Reuse the existing id if this name was already interned.
  int id = ts_query_capture_id_for_name(self, name, length);
  if (id >= 0) {
    return (uint16_t)id;
  }

  // Append the name plus a null terminator to the shared string buffer and
  // record its slice; its index becomes the new capture id.
  CaptureSlice capture = {
    .offset = self->capture_data.size,
    .length = length,
  };
  array_grow_by(&self->capture_data, length + 1);
  memcpy(&self->capture_data.contents[capture.offset], name, length);
  self->capture_data.contents[self->capture_data.size - 1] = 0;
  array_push(&self->capture_names, capture);
  return self->capture_names.size - 1;
}
|
||||
|
||||
static inline bool ts_query__pattern_map_search(
|
||||
const TSQuery *self,
|
||||
TSSymbol needle,
|
||||
uint32_t *result
|
||||
) {
|
||||
uint32_t base_index = self->wildcard_root_pattern_count;
|
||||
uint32_t size = self->pattern_map.size - base_index;
|
||||
if (size == 0) {
|
||||
*result = base_index;
|
||||
return false;
|
||||
}
|
||||
while (size > 1) {
|
||||
uint32_t half_size = size / 2;
|
||||
uint32_t mid_index = base_index + half_size;
|
||||
TSSymbol mid_symbol = self->steps.contents[
|
||||
self->pattern_map.contents[mid_index].step_index
|
||||
].symbol;
|
||||
if (needle > mid_symbol) base_index = mid_index;
|
||||
size -= half_size;
|
||||
}
|
||||
TSSymbol symbol = self->steps.contents[
|
||||
self->pattern_map.contents[base_index].step_index
|
||||
].symbol;
|
||||
if (needle > symbol) {
|
||||
*result = base_index;
|
||||
return false;
|
||||
} else if (needle == symbol) {
|
||||
*result = base_index;
|
||||
return true;
|
||||
} else {
|
||||
*result = base_index + 1;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Insert an entry into the pattern map, keeping it sorted by the symbol of
// each pattern's first step so that it can be binary-searched.
static inline void ts_query__pattern_map_insert(
  TSQuery *self,
  TSSymbol symbol,
  uint32_t start_step_index
) {
  // The search's insertion point is used even when no exact match exists.
  uint32_t index;
  ts_query__pattern_map_search(self, symbol, &index);
  array_insert(&self->pattern_map, index, ((PatternSlice) {
    .step_index = start_step_index,
    .pattern_index = self->pattern_map.size,
  }));
}
|
||||
|
||||
static TSQueryError ts_query_parse_pattern(
|
||||
TSQuery *self,
|
||||
Stream *stream,
|
||||
uint32_t depth,
|
||||
uint32_t *capture_count
|
||||
) {
|
||||
uint16_t starting_step_index = self->steps.size;
|
||||
|
||||
if (stream->next == 0) return TSQueryErrorSyntax;
|
||||
|
||||
// Finish the parent S-expression
|
||||
if (stream->next == ')') {
|
||||
return PARENT_DONE;
|
||||
}
|
||||
|
||||
// Parse a parenthesized node expression
|
||||
else if (stream->next == '(') {
|
||||
stream_advance(stream);
|
||||
stream_skip_whitespace(stream);
|
||||
TSSymbol symbol;
|
||||
|
||||
// Parse the wildcard symbol
|
||||
if (stream->next == '*') {
|
||||
symbol = WILDCARD_SYMBOL;
|
||||
stream_advance(stream);
|
||||
}
|
||||
|
||||
// Parse a normal node name
|
||||
else if (stream_is_ident_start(stream)) {
|
||||
const char *node_name = stream->input;
|
||||
stream_scan_identifier(stream);
|
||||
uint32_t length = stream->input - node_name;
|
||||
symbol = ts_query_intern_node_name(
|
||||
self,
|
||||
node_name,
|
||||
length,
|
||||
TSSymbolTypeRegular
|
||||
);
|
||||
if (!symbol) {
|
||||
stream_reset(stream, node_name);
|
||||
return TSQueryErrorNodeType;
|
||||
}
|
||||
} else {
|
||||
return TSQueryErrorSyntax;
|
||||
}
|
||||
|
||||
// Add a step for the node.
|
||||
array_push(&self->steps, ((QueryStep) {
|
||||
.depth = depth,
|
||||
.symbol = symbol,
|
||||
.field = 0,
|
||||
.capture_id = NONE,
|
||||
}));
|
||||
|
||||
// Parse the child patterns
|
||||
stream_skip_whitespace(stream);
|
||||
for (;;) {
|
||||
TSQueryError e = ts_query_parse_pattern(self, stream, depth + 1, capture_count);
|
||||
if (e == PARENT_DONE) {
|
||||
stream_advance(stream);
|
||||
break;
|
||||
} else if (e) {
|
||||
return e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse a double-quoted anonymous leaf node expression
|
||||
else if (stream->next == '"') {
|
||||
stream_advance(stream);
|
||||
|
||||
// Parse the string content
|
||||
const char *string_content = stream->input;
|
||||
while (stream->next && stream->next != '"') stream_advance(stream);
|
||||
uint32_t length = stream->input - string_content;
|
||||
|
||||
// Add a step for the node
|
||||
TSSymbol symbol = ts_query_intern_node_name(
|
||||
self,
|
||||
string_content,
|
||||
length,
|
||||
TSSymbolTypeAnonymous
|
||||
);
|
||||
if (!symbol) {
|
||||
stream_reset(stream, string_content);
|
||||
return TSQueryErrorNodeType;
|
||||
}
|
||||
array_push(&self->steps, ((QueryStep) {
|
||||
.depth = depth,
|
||||
.symbol = symbol,
|
||||
.field = 0,
|
||||
}));
|
||||
|
||||
if (stream->next != '"') return TSQueryErrorSyntax;
|
||||
stream_advance(stream);
|
||||
}
|
||||
|
||||
// Parse a field-prefixed pattern
|
||||
else if (stream_is_ident_start(stream)) {
|
||||
// Parse the field name
|
||||
const char *field_name = stream->input;
|
||||
stream_scan_identifier(stream);
|
||||
uint32_t length = stream->input - field_name;
|
||||
stream_skip_whitespace(stream);
|
||||
|
||||
if (stream->next != ':') {
|
||||
stream_reset(stream, field_name);
|
||||
return TSQueryErrorSyntax;
|
||||
}
|
||||
stream_advance(stream);
|
||||
stream_skip_whitespace(stream);
|
||||
|
||||
// Parse the pattern
|
||||
uint32_t step_index = self->steps.size;
|
||||
TSQueryError e = ts_query_parse_pattern(self, stream, depth, capture_count);
|
||||
if (e == PARENT_DONE) return TSQueryErrorSyntax;
|
||||
if (e) return e;
|
||||
|
||||
// Add the field name to the first step of the pattern
|
||||
TSFieldId field_id = ts_language_field_id_for_name(
|
||||
self->language,
|
||||
field_name,
|
||||
length
|
||||
);
|
||||
if (!field_id) {
|
||||
stream->input = field_name;
|
||||
return TSQueryErrorField;
|
||||
}
|
||||
self->steps.contents[step_index].field = field_id;
|
||||
}
|
||||
|
||||
// Parse a wildcard pattern
|
||||
else if (stream->next == '*') {
|
||||
stream_advance(stream);
|
||||
stream_skip_whitespace(stream);
|
||||
|
||||
// Add a step that matches any kind of node
|
||||
array_push(&self->steps, ((QueryStep) {
|
||||
.depth = depth,
|
||||
.symbol = WILDCARD_SYMBOL,
|
||||
.field = 0,
|
||||
}));
|
||||
}
|
||||
|
||||
// No match
|
||||
else {
|
||||
return TSQueryErrorSyntax;
|
||||
}
|
||||
|
||||
stream_skip_whitespace(stream);
|
||||
|
||||
// Parse a '@'-suffixed capture pattern
|
||||
if (stream->next == '@') {
|
||||
stream_advance(stream);
|
||||
stream_skip_whitespace(stream);
|
||||
|
||||
// Parse the capture name
|
||||
if (!stream_is_ident_start(stream)) return TSQueryErrorSyntax;
|
||||
const char *capture_name = stream->input;
|
||||
stream_scan_identifier(stream);
|
||||
uint32_t length = stream->input - capture_name;
|
||||
|
||||
// Add the capture id to the first step of the pattern
|
||||
uint16_t capture_id = ts_query_intern_capture_name(
|
||||
self,
|
||||
capture_name,
|
||||
length
|
||||
);
|
||||
self->steps.contents[starting_step_index].capture_id = capture_id;
|
||||
(*capture_count)++;
|
||||
|
||||
stream_skip_whitespace(stream);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Create a query by parsing a string of S-expression patterns.
//
// On failure, returns NULL and reports the error's byte offset and kind
// through `error_offset` and `error_type`; on success the returned query
// owns all of its internal arrays and must be freed with ts_query_delete.
//
// NOTE(review): `capture_data` and `capture_names` are only
// zero-initialized by the compound literal below (they are not listed in
// the initializer); presumably `array_delete` tolerates a zeroed array —
// verify.
TSQuery *ts_query_new(
  const TSLanguage *language,
  const char *source,
  uint32_t source_len,
  uint32_t *error_offset,
  TSQueryError *error_type
) {
  TSQuery *self = ts_malloc(sizeof(TSQuery));
  *self = (TSQuery) {
    .steps = array_new(),
    .pattern_map = array_new(),
    .wildcard_root_pattern_count = 0,
    .max_capture_count = 0,
    .language = language,
  };

  // Parse all of the S-expressions in the given string.
  Stream stream = stream_new(source, source_len);
  stream_skip_whitespace(&stream);
  uint32_t start_step_index;
  for (;;) {
    start_step_index = self->steps.size;
    uint32_t capture_count = 0;
    *error_type = ts_query_parse_pattern(self, &stream, 0, &capture_count);
    // Terminate each pattern's steps with a sentinel depth so that query
    // execution can detect when a pattern has been fully matched.
    array_push(&self->steps, ((QueryStep) { .depth = PATTERN_DONE_MARKER }));

    // If any pattern could not be parsed, then report the error information
    // and terminate.
    if (*error_type) {
      *error_offset = stream.input - source;
      ts_query_delete(self);
      return NULL;
    }

    // Maintain a map that can look up patterns for a given root symbol.
    ts_query__pattern_map_insert(
      self,
      self->steps.contents[start_step_index].symbol,
      start_step_index
    );
    if (self->steps.contents[start_step_index].symbol == WILDCARD_SYMBOL) {
      self->wildcard_root_pattern_count++;
    }

    // Track the largest capture count of any one pattern, so that query
    // contexts can size their capture-list pool accordingly.
    if (capture_count > self->max_capture_count) {
      self->max_capture_count = capture_count;
    }

    if (stream.input == stream.end) break;
  }

  return self;
}
|
||||
|
||||
// Free a query and all of the memory it owns. Accepts NULL, in which
// case this is a no-op.
void ts_query_delete(TSQuery *self) {
  if (!self) return;
  array_delete(&self->capture_names);
  array_delete(&self->capture_data);
  array_delete(&self->pattern_map);
  array_delete(&self->steps);
  ts_free(self);
}
|
||||
|
||||
uint32_t ts_query_capture_count(const TSQuery *self) {
|
||||
return self->capture_names.size;
|
||||
}
|
||||
|
||||
const char *ts_query_capture_name_for_id(
|
||||
const TSQuery *self,
|
||||
uint32_t index,
|
||||
uint32_t *length
|
||||
) {
|
||||
CaptureSlice name = self->capture_names.contents[index];
|
||||
*length = name.length;
|
||||
return &self->capture_data.contents[name.offset];
|
||||
}
|
||||
|
||||
int ts_query_capture_id_for_name(
|
||||
const TSQuery *self,
|
||||
const char *name,
|
||||
uint32_t length
|
||||
) {
|
||||
for (unsigned i = 0; i < self->capture_names.size; i++) {
|
||||
CaptureSlice existing = self->capture_names.contents[i];
|
||||
if (
|
||||
existing.length == length &&
|
||||
!strncmp(&self->capture_data.contents[existing.offset], name, length)
|
||||
) return i;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
/***************
|
||||
* QueryContext
|
||||
***************/
|
||||
|
||||
TSQueryContext *ts_query_context_new(const TSQuery *query) {
|
||||
TSQueryContext *self = ts_malloc(sizeof(TSQueryContext));
|
||||
*self = (TSQueryContext) {
|
||||
.query = query,
|
||||
.ascending = false,
|
||||
.states = array_new(),
|
||||
.finished_states = array_new(),
|
||||
.capture_list_pool = capture_list_pool_new(query->max_capture_count),
|
||||
};
|
||||
return self;
|
||||
}
|
||||
|
||||
// Tear down a query context, releasing every resource it owns. The
// query itself is not freed here; it belongs to the caller.
void ts_query_context_delete(TSQueryContext *self) {
  capture_list_pool_delete(&self->capture_list_pool);
  ts_tree_cursor_delete(&self->cursor);
  array_delete(&self->finished_states);
  array_delete(&self->states);
  ts_free(self);
}
|
||||
|
||||
// Begin executing this context's query on the subtree rooted at `node`,
// discarding any state left over from a previous execution.
void ts_query_context_exec(TSQueryContext *self, TSNode node) {
  self->ascending = false;
  self->depth = 0;
  capture_list_pool_clear(&self->capture_list_pool);
  ts_tree_cursor_reset(&self->cursor, node);
  array_clear(&self->finished_states);
  array_clear(&self->states);
}
|
||||
|
||||
// Advance to the next match of the query within the tree.
//
// Returns true when a match has been found; it can then be inspected via
// ts_query_context_matched_pattern_index and
// ts_query_context_matched_captures. Returns false when the walk of the
// tree is complete and no further matches exist.
//
// Fixes relative to the previous revision:
//  1. The do/while over the pattern map used `continue` on a field
//     mismatch, which re-tested the unchanged loop condition without
//     advancing `i` — an infinite loop. It is now a `for` loop.
//  2. `array_push(&self->states, *state)` evaluated `*state` after the
//     push may have reallocated the array, reading freed memory. The
//     state is now copied to a local first and pointers are re-derived.
bool ts_query_context_next(TSQueryContext *self) {
  // Discard the match that was reported by the previous call.
  if (self->finished_states.size > 0) {
    array_pop(&self->finished_states);
  }

  while (self->finished_states.size == 0) {
    if (self->ascending) {
      // Leaving a node. Remove any states that were started within this
      // node and are still not complete, returning their capture lists
      // to the pool; surviving states are compacted toward the front.
      uint32_t deleted_count = 0;
      for (unsigned i = 0, n = self->states.size; i < n; i++) {
        QueryState *state = &self->states.contents[i];
        if (state->start_depth == self->depth) {
          capture_list_pool_release(
            &self->capture_list_pool,
            state->capture_list_id
          );
          deleted_count++;
        } else if (deleted_count > 0) {
          self->states.contents[i - deleted_count] = *state;
        }
      }
      self->states.size -= deleted_count;

      // Move laterally if possible, otherwise climb; when neither is
      // possible the traversal is over.
      if (ts_tree_cursor_goto_next_sibling(&self->cursor)) {
        self->ascending = false;
      } else if (ts_tree_cursor_goto_parent(&self->cursor)) {
        self->depth--;
      } else {
        return false;
      }
    } else {
      // Entering a node. The node's field id is computed lazily (at most
      // once per node) and cached in these locals, since it is only
      // needed when some pattern step names a field.
      TSFieldId field_id = NONE;
      bool field_occurs_in_later_sibling = false;
      TSNode node = ts_tree_cursor_current_node(&self->cursor);
      TSSymbol symbol = ts_node_symbol(node);

      // Add new states for any patterns whose root node is a wildcard.
      // These occupy the front of the pattern map.
      for (unsigned i = 0; i < self->query->wildcard_root_pattern_count; i++) {
        PatternSlice *slice = &self->query->pattern_map.contents[i];
        QueryStep *step = &self->query->steps.contents[slice->step_index];

        // Check that the node matches the criteria for the first step
        // of the pattern.
        if (step->field) {
          if (field_id == NONE) {
            field_id = ts_tree_cursor_current_field_id_ext(
              &self->cursor,
              &field_occurs_in_later_sibling
            );
          }
          if (field_id != step->field) continue;
        }

        // Add a new state at the start of this pattern. If the capture
        // list pool is exhausted, stop starting new states.
        uint32_t capture_list_id = capture_list_pool_acquire(
          &self->capture_list_pool
        );
        if (capture_list_id == NONE) break;
        array_push(&self->states, ((QueryState) {
          .step_index = slice->step_index,
          .pattern_index = slice->pattern_index,
          .capture_list_id = capture_list_id,
        }));
      }

      // Add new states for any patterns whose root node matches this
      // node. The pattern map is grouped by root symbol, so all
      // candidates are contiguous starting at the found index.
      unsigned i;
      if (ts_query__pattern_map_search(self->query, symbol, &i)) {
        for (; i < self->query->pattern_map.size; i++) {
          PatternSlice *slice = &self->query->pattern_map.contents[i];
          QueryStep *step = &self->query->steps.contents[slice->step_index];

          // Stop once past the entries for this symbol. (The first entry
          // found by the search is guaranteed to match.)
          if (step->symbol != symbol) break;

          // Check the field criterion for the pattern's first step. With
          // a `for` loop, `continue` correctly advances `i` (fix #1).
          if (step->field) {
            if (field_id == NONE) {
              field_id = ts_tree_cursor_current_field_id_ext(
                &self->cursor,
                &field_occurs_in_later_sibling
              );
            }
            if (field_id != step->field) continue;
          }

          // If the node matches the first step of the pattern, then add
          // a new in-progress state. First, acquire a list to hold the
          // pattern's captures.
          uint32_t capture_list_id = capture_list_pool_acquire(
            &self->capture_list_pool
          );
          if (capture_list_id == NONE) break;

          array_push(&self->states, ((QueryState) {
            .pattern_index = slice->pattern_index,
            .step_index = slice->step_index + 1,
            .start_depth = self->depth,
            .capture_list_id = capture_list_id,
            .capture_count = 0,
          }));
        }
      }

      // Update all of the in-progress states with the current node.
      for (unsigned i = 0, n = self->states.size; i < n; i++) {
        QueryState *state = &self->states.contents[i];
        QueryStep *step = &self->query->steps.contents[state->step_index];

        // Check that the node matches all of the criteria for the next
        // step of the pattern.
        if (state->start_depth + step->depth != self->depth) continue;
        if (step->symbol && step->symbol != symbol) continue;
        if (step->field) {
          // Only compute the current field if it is needed for the
          // current step of some in-progress pattern.
          if (field_id == NONE) {
            field_id = ts_tree_cursor_current_field_id_ext(
              &self->cursor,
              &field_occurs_in_later_sibling
            );
          }
          if (field_id != step->field) continue;
        }

        // Some patterns can match their root node in multiple ways,
        // capturing different children. If this pattern step could match
        // later children within the same parent, then this query state
        // cannot simply be updated in place. It must be split into two
        // states: one that captures this node, and one which skips over
        // this node, to preserve the possibility of capturing later
        // siblings.
        QueryState *next_state = state;
        if (step->depth > 0 && (!step->field || field_occurs_in_later_sibling)) {
          uint32_t capture_list_id = capture_list_pool_acquire(
            &self->capture_list_pool
          );
          if (capture_list_id != NONE) {
            // Copy the state into a local before pushing: array_push may
            // grow (reallocate) the array while evaluating its argument,
            // which would read `*state` out of freed memory (fix #2).
            QueryState copy = *state;
            array_push(&self->states, copy);

            // Re-derive both pointers after the potential reallocation.
            state = &self->states.contents[i];
            next_state = array_back(&self->states);
            next_state->capture_list_id = capture_list_id;
            // NOTE(review): the new state gets a fresh capture list but
            // keeps the old capture_count; captures recorded so far are
            // not copied into the new list — verify this is intended.
          }
        }

        // Record a capture if this step carries a capture id.
        if (step->capture_id != NONE) {
          TSQueryCapture *capture_list = capture_list_pool_get(
            &self->capture_list_pool,
            next_state->capture_list_id
          );
          capture_list[next_state->capture_count++] = (TSQueryCapture) {
            node,
            step->capture_id
          };
        }

        // If the pattern is now done, then move this state into the
        // finished list, where the caller can inspect it.
        next_state->step_index++;
        QueryStep *next_step = step + 1;
        if (next_step->depth == PATTERN_DONE_MARKER) {
          array_push(&self->finished_states, *next_state);
          if (next_state == state) {
            array_erase(&self->states, i);
            i--;
            n--;
          } else {
            array_pop(&self->states);
          }
        }
      }

      // Descend into the node's children if any; otherwise begin
      // ascending on the next iteration.
      if (ts_tree_cursor_goto_first_child(&self->cursor)) {
        self->depth++;
      } else {
        self->ascending = true;
      }
    }
  }

  return true;
}
|
||||
|
||||
uint32_t ts_query_context_matched_pattern_index(const TSQueryContext *self) {
|
||||
if (self->finished_states.size > 0) {
|
||||
QueryState *state = array_back(&self->finished_states);
|
||||
return state->pattern_index;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
const TSQueryCapture *ts_query_context_matched_captures(
|
||||
const TSQueryContext *self,
|
||||
uint32_t *count
|
||||
) {
|
||||
if (self->finished_states.size > 0) {
|
||||
QueryState *state = array_back(&self->finished_states);
|
||||
*count = state->capture_count;
|
||||
return capture_list_pool_get(
|
||||
(CaptureListPool *)&self->capture_list_pool,
|
||||
state->capture_list_id
|
||||
);
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
|
@ -244,7 +244,12 @@ TSNode ts_tree_cursor_current_node(const TSTreeCursor *_self) {
|
|||
);
|
||||
}
|
||||
|
||||
TSFieldId ts_tree_cursor_current_field_id(const TSTreeCursor *_self) {
|
||||
static inline TSFieldId ts_tree_cursor__current_field_info(
|
||||
const TSTreeCursor *_self,
|
||||
const TSFieldMapEntry **field_map,
|
||||
const TSFieldMapEntry **field_map_end,
|
||||
uint32_t *child_index
|
||||
) {
|
||||
const TreeCursor *self = (const TreeCursor *)_self;
|
||||
|
||||
// Walk up the tree, visiting the current node and its invisible ancestors.
|
||||
|
|
@ -264,25 +269,61 @@ TSFieldId ts_tree_cursor_current_field_id(const TSTreeCursor *_self) {
|
|||
}
|
||||
}
|
||||
|
||||
const TSFieldMapEntry *field_map, *field_map_end;
|
||||
if (ts_subtree_extra(*entry->subtree)) break;
|
||||
|
||||
ts_language_field_map(
|
||||
self->tree->language,
|
||||
parent_entry->subtree->ptr->production_id,
|
||||
&field_map, &field_map_end
|
||||
field_map, field_map_end
|
||||
);
|
||||
|
||||
while (field_map < field_map_end) {
|
||||
if (
|
||||
!ts_subtree_extra(*entry->subtree) &&
|
||||
!field_map->inherited &&
|
||||
field_map->child_index == entry->structural_child_index
|
||||
) return field_map->field_id;
|
||||
field_map++;
|
||||
for (const TSFieldMapEntry *i = *field_map; i < *field_map_end; i++) {
|
||||
if (!i->inherited && i->child_index == entry->structural_child_index) {
|
||||
*child_index = entry->structural_child_index;
|
||||
return i->field_id;
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Like ts_tree_cursor_current_field_id, but additionally reports (via
// `field_has_additional`) whether some later sibling of the current node
// carries the same field.
//
// NOTE(review): `*field_has_additional` is only ever set to true here —
// it is never initialized or cleared. Callers must initialize it to
// false before calling (as ts_query_context_next does).
TSFieldId ts_tree_cursor_current_field_id_ext(
  const TSTreeCursor *self,
  bool *field_has_additional
) {
  uint32_t child_index;
  const TSFieldMapEntry *field_map, *field_map_end;
  TSFieldId field_id = ts_tree_cursor__current_field_info(
    self,
    &field_map,
    &field_map_end,
    &child_index
  );

  // After finding the field, check if any other later children have
  // the same field name.
  if (field_id) {
    for (const TSFieldMapEntry *i = field_map; i < field_map_end; i++) {
      if (i->field_id == field_id && i->child_index > child_index) {
        *field_has_additional = true;
      }
    }
  }

  return field_id;
}
|
||||
|
||||
|
||||
TSFieldId ts_tree_cursor_current_field_id(const TSTreeCursor *self) {
|
||||
uint32_t child_index;
|
||||
const TSFieldMapEntry *field_map, *field_map_end;
|
||||
return ts_tree_cursor__current_field_info(
|
||||
self,
|
||||
&field_map,
|
||||
&field_map_end,
|
||||
&child_index
|
||||
);
|
||||
}
|
||||
|
||||
const char *ts_tree_cursor_current_field_name(const TSTreeCursor *_self) {
|
||||
TSFieldId id = ts_tree_cursor_current_field_id(_self);
|
||||
if (id) {
|
||||
|
|
|
|||
|
|
@ -16,5 +16,6 @@ typedef struct {
|
|||
} TreeCursor;
|
||||
|
||||
void ts_tree_cursor_init(TreeCursor *, TSNode);
|
||||
TSFieldId ts_tree_cursor_current_field_id_ext(const TSTreeCursor *, bool *);
|
||||
|
||||
#endif // TREE_SITTER_TREE_CURSOR_H_
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue