Start work on an API for querying trees

This commit is contained in:
Max Brunsfeld 2019-09-09 15:41:13 -07:00
parent 4151a428ec
commit fe7c74e7aa
10 changed files with 1430 additions and 12 deletions

View file

@ -4,4 +4,5 @@ mod highlight_test;
mod node_test;
mod parser_test;
mod properties_test;
mod query_test;
mod tree_test;

216
cli/src/tests/query_test.rs Normal file
View file

@ -0,0 +1,216 @@
use super::helpers::allocations;
use super::helpers::fixtures::get_language;
use tree_sitter::{Parser, Query, QueryError, QueryMatch};
/// Verifies that malformed query source is rejected with a `QueryError::Syntax`
/// carrying the byte offset at which the parser gave up.
#[test]
fn test_query_errors_on_invalid_syntax() {
    allocations::start_recording();

    let language = get_language("javascript");

    // Well-formed patterns are accepted.
    assert!(Query::new(language, "(if_statement)").is_ok());
    assert!(Query::new(language, "(if_statement condition:(identifier))").is_ok());

    // Mismatched parens
    assert_eq!(
        Query::new(language, "(if_statement"),
        Err(QueryError::Syntax(13))
    );
    assert_eq!(
        Query::new(language, "(if_statement))"),
        Err(QueryError::Syntax(14))
    );

    // Return an error at the *beginning* of a bare identifier not followed by a colon.
    assert_eq!(
        Query::new(language, "(if_statement identifier)"),
        Err(QueryError::Syntax(14))
    );

    // If there's a colon but no pattern, return an error at the end of the colon.
    // (This case used to be asserted twice; the duplicate added no coverage.)
    assert_eq!(
        Query::new(language, "(if_statement condition:)"),
        Err(QueryError::Syntax(24))
    );

    allocations::stop_recording();
}
/// Verifies that unknown node types and unknown field names are reported with
/// the offending name taken from the query source.
#[test]
fn test_query_errors_on_invalid_symbols() {
    allocations::start_recording();

    let language = get_language("javascript");

    // Each entry pairs a query source with the error it must produce.
    let cases = [
        ("(non_existent1)", QueryError::NodeType("non_existent1")),
        (
            "(if_statement (non_existent2))",
            QueryError::NodeType("non_existent2"),
        ),
        (
            "(if_statement condition: (non_existent3))",
            QueryError::NodeType("non_existent3"),
        ),
        (
            "(if_statement not_a_field: (identifier))",
            QueryError::Field("not_a_field"),
        ),
    ];

    for (source, expected) in cases.iter() {
        assert_eq!(Query::new(language, *source).err().as_ref(), Some(expected));
    }

    allocations::stop_recording();
}
/// Verifies that capture names are reported in the order in which they first
/// appear in the query source, across multiple patterns.
#[test]
fn test_query_capture_names() {
    allocations::start_recording();

    {
        let language = get_language("javascript");
        let patterns = r#"
(if_statement
condition: (binary_expression
left: * @left-operand
operator: "||"
right: * @right-operand)
consequence: (statement_block) @body)
(while_statement
condition:* @loop-condition)
"#;
        let query = Query::new(language, patterns).unwrap();

        let names: Vec<&str> = query.capture_names().iter().map(String::as_str).collect();
        assert_eq!(
            names,
            ["left-operand", "right-operand", "body", "loop-condition"]
        );
        // `names` and then `query` are dropped here, before recording stops.
    }

    allocations::stop_recording();
}
/// Runs a single-capture pattern over a small program and checks that every
/// function declaration (including a nested one) yields one match.
#[test]
fn test_query_exec_with_simple_pattern() {
    allocations::start_recording();

    {
        let language = get_language("javascript");
        let source = "function one() { two(); function three() {} }";

        let query = Query::new(
            language,
            "(function_declaration name: (identifier) @fn-name)",
        )
        .unwrap();

        let mut parser = Parser::new();
        parser.set_language(language).unwrap();
        let tree = parser.parse(source, None).unwrap();

        // Declared last so that it is dropped first: the context borrows the query.
        let context = query.context();

        assert_eq!(
            collect_matches(context.exec(tree.root_node()), &query, source),
            &[
                (0, vec![("fn-name", "one")]),
                (0, vec![("fn-name", "three")])
            ],
        );
        // Everything allocated above is released at the end of this scope.
    }

    allocations::stop_recording();
}
/// Checks that one root node (the class) can produce several matches, one per
/// method, each repeating the class-name capture.
#[test]
fn test_query_exec_with_multiple_matches_same_root() {
    allocations::start_recording();

    {
        let language = get_language("javascript");
        let pattern = "(class_declaration
name: (identifier) @the-class-name
(class_body
(method_definition
name: (property_identifier) @the-method-name)))";
        let source = "
class Person {
// the constructor
constructor(name) { this.name = name; }
// the getter
getFullName() { return this.name; }
}
";

        let query = Query::new(language, pattern).unwrap();

        let mut parser = Parser::new();
        parser.set_language(language).unwrap();
        let tree = parser.parse(source, None).unwrap();

        // Declared last so that it is dropped first: the context borrows the query.
        let context = query.context();

        let constructor_match = (
            0,
            vec![
                ("the-class-name", "Person"),
                ("the-method-name", "constructor"),
            ],
        );
        let getter_match = (
            0,
            vec![
                ("the-class-name", "Person"),
                ("the-method-name", "getFullName"),
            ],
        );

        assert_eq!(
            collect_matches(context.exec(tree.root_node()), &query, source),
            &[constructor_match, getter_match],
        );
        // Everything allocated above is released at the end of this scope.
    }

    allocations::stop_recording();
}
/// Drains `matches` into plain data that is easy to compare in assertions.
///
/// Each match becomes `(pattern_index, captures)`, where every capture is a
/// `(capture_name, captured_source_text)` pair. Capture names are looked up in
/// `query`; node text is sliced out of `source`.
///
/// Panics if a captured node's text is not valid UTF-8.
fn collect_matches<'a>(
    matches: impl Iterator<Item = QueryMatch<'a>>,
    query: &'a Query,
    source: &'a str,
) -> Vec<(usize, Vec<(&'a str, &'a str)>)> {
    let names = query.capture_names();
    let text = source.as_bytes();

    let mut collected = Vec::new();
    for found in matches {
        let pattern_index = found.pattern_index();

        let mut captures = Vec::new();
        for (capture_id, node) in found.captures() {
            captures.push((names[capture_id].as_str(), node.utf8_text(text).unwrap()));
        }

        collected.push((pattern_index, captures));
    }
    collected
}

View file

@ -19,6 +19,16 @@ pub struct TSParser {
pub struct TSTree {
_unused: [u8; 0],
}
/// Opaque FFI handle for a compiled query.
///
/// The zero-sized `_unused` field is only a placeholder: the real layout lives
/// on the C side, and this type is only ever used behind a raw pointer.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct TSQuery {
_unused: [u8; 0],
}
/// Opaque FFI handle for the mutable state used while executing a query.
///
/// Like `TSQuery`, this is a placeholder for a C-side struct and is only ever
/// used behind a raw pointer.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct TSQueryContext {
_unused: [u8; 0],
}
pub const TSInputEncoding_TSInputEncodingUTF8: TSInputEncoding = 0;
pub const TSInputEncoding_TSInputEncodingUTF16: TSInputEncoding = 1;
pub type TSInputEncoding = u32;
@ -93,6 +103,17 @@ pub struct TSTreeCursor {
pub id: *const ::std::os::raw::c_void,
pub context: [u32; 2usize],
}
/// One captured node within a match, as laid out by the C API.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct TSQueryCapture {
/// The syntax node that was captured.
pub node: TSNode,
/// Numeric id of the capture name (see `ts_query_capture_name_for_id`).
pub index: u32,
}
// Values of the C `TSQueryError` enum, written by `ts_query_new` through its
// `error_type` out-parameter.
/// No error.
pub const TSQueryError_TSQueryErrorNone: TSQueryError = 0;
/// The query source is syntactically malformed.
pub const TSQueryError_TSQueryErrorSyntax: TSQueryError = 1;
/// The query names a node type that the language does not define.
pub const TSQueryError_TSQueryErrorNodeType: TSQueryError = 2;
/// The query names a field that the language does not define.
pub const TSQueryError_TSQueryErrorField: TSQueryError = 3;
pub type TSQueryError = u32;
extern "C" {
#[doc = " Create a new parser."]
pub fn ts_parser_new() -> *mut TSParser;
@ -538,6 +559,65 @@ extern "C" {
extern "C" {
pub fn ts_tree_cursor_copy(arg1: *const TSTreeCursor) -> TSTreeCursor;
}
extern "C" {
#[doc = " Create a new query based on a given language and string containing"]
#[doc = " one or more S-expression patterns."]
#[doc = ""]
#[doc = " If all of the given patterns are valid, this returns a `TSQuery`."]
#[doc = " If a pattern is invalid, this returns `NULL`, and provides two pieces"]
#[doc = " of information about the problem:"]
#[doc = " 1. The byte offset of the error is written to the `error_offset` parameter."]
#[doc = " 2. The type of error is written to the `error_type` parameter."]
pub fn ts_query_new(
arg1: *const TSLanguage,
source: *const ::std::os::raw::c_char,
source_len: u32,
error_offset: *mut u32,
error_type: *mut TSQueryError,
) -> *mut TSQuery;
}
extern "C" {
#[doc = " Delete a query, freeing all of the memory that it used."]
pub fn ts_query_delete(arg1: *mut TSQuery);
}
extern "C" {
/// Get the number of distinct capture names in the query.
pub fn ts_query_capture_count(arg1: *const TSQuery) -> u32;
}
extern "C" {
/// Get the name of the capture with the given numeric id. The name's byte
/// length is written to `length`.
pub fn ts_query_capture_name_for_id(
self_: *const TSQuery,
index: u32,
length: *mut u32,
) -> *const ::std::os::raw::c_char;
}
extern "C" {
/// Get the numeric id of the capture with the given name, or -1 if the
/// query has no capture with that name.
pub fn ts_query_capture_id_for_name(
self_: *const TSQuery,
name: *const ::std::os::raw::c_char,
length: u32,
) -> ::std::os::raw::c_int;
}
extern "C" {
/// Create a new context for executing the given query.
pub fn ts_query_context_new(arg1: *const TSQuery) -> *mut TSQueryContext;
}
extern "C" {
/// Delete a query context, freeing all of the memory that it used.
pub fn ts_query_context_delete(arg1: *mut TSQueryContext);
}
extern "C" {
/// Start running the context's query on the given node.
pub fn ts_query_context_exec(arg1: *mut TSQueryContext, arg2: TSNode);
}
extern "C" {
/// Advance to the next match; returns `false` once there are no more.
pub fn ts_query_context_next(arg1: *mut TSQueryContext) -> bool;
}
extern "C" {
/// Get the index of the pattern that produced the current match.
pub fn ts_query_context_matched_pattern_index(arg1: *const TSQueryContext) -> u32;
}
extern "C" {
/// Get the captures of the current match. The number of captures is
/// written through `arg2`.
pub fn ts_query_context_matched_captures(
arg1: *const TSQueryContext,
arg2: *mut u32,
) -> *const TSQueryCapture;
}
extern "C" {
#[doc = " Get the number of distinct node types in the language."]
pub fn ts_language_symbol_count(arg1: *const TSLanguage) -> u32;

View file

@ -17,7 +17,7 @@ use std::ffi::CStr;
use std::marker::PhantomData;
use std::os::raw::{c_char, c_void};
use std::sync::atomic::AtomicUsize;
use std::{fmt, ptr, str, u16};
use std::{char, fmt, ptr, slice, str, u16};
pub const LANGUAGE_VERSION: usize = ffi::TREE_SITTER_LANGUAGE_VERSION;
pub const PARSER_HEADER: &'static str = include_str!("../include/tree_sitter/parser.h");
@ -136,6 +136,23 @@ pub struct TreePropertyCursor<'a, P> {
source: &'a [u8],
}
/// A compiled tree query, owning the underlying C `TSQuery`.
///
/// The C object is freed when this value is dropped.
#[derive(Debug)]
pub struct Query {
// Owned pointer returned by `ffi::ts_query_new`.
ptr: *mut ffi::TSQuery,
// Capture names copied out of the C query at construction time, indexed by
// capture id.
capture_names: Vec<String>,
}
/// Execution state for running a `Query`, owning the underlying C
/// `TSQueryContext`. The lifetime parameter ties the context to the `Query`
/// that created it (see `Query::context`).
pub struct QueryContext<'a>(*mut ffi::TSQueryContext, PhantomData<&'a ()>);
/// A view of the match a `QueryContext` is currently positioned on.
///
/// This holds only a reference to the context; its accessors read the
/// context's current state rather than a snapshot taken at creation time.
pub struct QueryMatch<'a>(&'a QueryContext<'a>);
/// The reason a query source string could not be compiled.
#[derive(Debug, PartialEq, Eq)]
pub enum QueryError<'a> {
/// Malformed syntax; the payload is the byte offset reported by the parser.
Syntax(usize),
/// Unknown node type; the payload is the name, borrowed from the query source.
NodeType(&'a str),
/// Unknown field name; the payload is the name, borrowed from the query source.
Field(&'a str),
}
impl Language {
pub fn version(&self) -> usize {
unsafe { ffi::ts_language_version(self.0) as usize }
@ -921,6 +938,117 @@ impl<'a, P> TreePropertyCursor<'a, P> {
}
}
impl Query {
    /// Compiles `source`, a string of one or more S-expression patterns, into
    /// a query for `language`.
    ///
    /// # Errors
    ///
    /// * `QueryError::Syntax(offset)` if the source is malformed; `offset` is
    ///   the byte offset reported by the C parser.
    /// * `QueryError::NodeType(name)` / `QueryError::Field(name)` if the source
    ///   refers to a node type or field that the language does not define.
    ///   `name` is borrowed from `source`.
    pub fn new(language: Language, source: &str) -> Result<Self, QueryError> {
        let mut error_offset = 0u32;
        let mut error_type: ffi::TSQueryError = 0;
        let bytes = source.as_bytes();

        // NOTE(review): the C API takes a 32-bit length, so a source longer
        // than `u32::MAX` bytes would be truncated by this cast.
        let ptr = unsafe {
            ffi::ts_query_new(
                language.0,
                bytes.as_ptr() as *const c_char,
                bytes.len() as u32,
                &mut error_offset as *mut u32,
                &mut error_type as *mut ffi::TSQueryError,
            )
        };

        if ptr.is_null() {
            let offset = error_offset as usize;
            Err(match error_type {
                ffi::TSQueryError_TSQueryErrorNodeType | ffi::TSQueryError_TSQueryErrorField => {
                    // The offset points at the start of the offending name;
                    // recover the name by scanning forward over identifier
                    // characters. `.` is included because the C identifier
                    // scanner accepts it as well.
                    let suffix = source.split_at(offset).1;
                    let end_offset = suffix
                        .find(|c: char| {
                            !(char::is_alphanumeric(c) || c == '_' || c == '-' || c == '.')
                        })
                        // The name may run to the end of the source. The
                        // fallback must be relative to `suffix`, not `source`,
                        // or the `split_at` below panics whenever `offset > 0`.
                        .unwrap_or(suffix.len());
                    let name = suffix.split_at(end_offset).0;
                    if error_type == ffi::TSQueryError_TSQueryErrorNodeType {
                        QueryError::NodeType(name)
                    } else {
                        QueryError::Field(name)
                    }
                }
                _ => QueryError::Syntax(offset),
            })
        } else {
            // Copy the capture names out of the C query once, so that
            // `capture_names` can hand out a plain slice.
            let capture_count = unsafe { ffi::ts_query_capture_count(ptr) };
            let capture_names = (0..capture_count)
                .map(|i| unsafe {
                    let mut length = 0u32;
                    let name = ffi::ts_query_capture_name_for_id(ptr, i, &mut length as *mut u32)
                        as *const u8;
                    let name = slice::from_raw_parts(name, length as usize);
                    // Capture names are byte ranges copied from `source`, which
                    // is a `&str`, so they are valid UTF-8.
                    let name = str::from_utf8_unchecked(name);
                    name.to_string()
                })
                .collect();
            Ok(Query { ptr, capture_names })
        }
    }

    /// Returns the query's capture names, indexed by capture id.
    pub fn capture_names(&self) -> &[String] {
        &self.capture_names
    }

    /// Creates a new execution context for this query. The context borrows the
    /// query and so cannot outlive it.
    pub fn context(&self) -> QueryContext {
        let context = unsafe { ffi::ts_query_context_new(self.ptr) };
        QueryContext(context, PhantomData)
    }
}
impl<'a> QueryContext<'a> {
    /// Starts running the query on `node` and returns an iterator over its
    /// matches.
    ///
    /// Each call to the iterator's `next` advances the underlying C context;
    /// every yielded `QueryMatch` refers back to this same context.
    pub fn exec(&'a self, node: Node<'a>) -> impl Iterator<Item = QueryMatch<'a>> + 'a {
        unsafe { ffi::ts_query_context_exec(self.0, node.0) };

        std::iter::from_fn(move || {
            let has_match = unsafe { ffi::ts_query_context_next(self.0) };
            match has_match {
                true => Some(QueryMatch(self)),
                false => None,
            }
        })
    }
}
impl<'a> QueryMatch<'a> {
    /// Index of the pattern that produced the context's current match.
    pub fn pattern_index(&self) -> usize {
        let context = (self.0).0;
        let index = unsafe { ffi::ts_query_context_matched_pattern_index(context) };
        index as usize
    }

    /// Iterates over the captures of the context's current match as
    /// `(capture_id, node)` pairs.
    ///
    /// Panics if the C API reports a capture whose node `Node::new` rejects.
    pub fn captures(&self) -> impl ExactSizeIterator<Item = (usize, Node)> {
        let context = (self.0).0;
        let mut count = 0u32;
        let list = unsafe {
            let first = ffi::ts_query_context_matched_captures(context, &mut count as *mut u32);
            slice::from_raw_parts(first, count as usize)
        };
        list.iter()
            .map(move |entry| (entry.index as usize, Node::new(entry.node).unwrap()))
    }
}
impl PartialEq for Query {
fn eq(&self, other: &Self) -> bool {
self.ptr == other.ptr
}
}
/// Frees the underlying C query when the wrapper goes out of scope.
impl Drop for Query {
    fn drop(&mut self) {
        let raw = self.ptr;
        unsafe {
            ffi::ts_query_delete(raw);
        }
    }
}
/// Frees the underlying C query context when the wrapper goes out of scope.
impl<'a> Drop for QueryContext<'a> {
    fn drop(&mut self) {
        let raw = self.0;
        unsafe {
            ffi::ts_query_context_delete(raw);
        }
    }
}
impl Point {
pub fn new(row: usize, column: usize) -> Self {
Point { row, column }

View file

@ -26,6 +26,8 @@ typedef uint16_t TSFieldId;
typedef struct TSLanguage TSLanguage;
typedef struct TSParser TSParser;
typedef struct TSTree TSTree;
typedef struct TSQuery TSQuery;
typedef struct TSQueryContext TSQueryContext;
typedef enum {
TSInputEncodingUTF8,
@ -87,6 +89,18 @@ typedef struct {
uint32_t context[2];
} TSTreeCursor;
/* A single captured node within a query match. */
typedef struct {
/* The syntax node that was captured. */
TSNode node;
/* Numeric id of the capture name (see `ts_query_capture_name_for_id`). */
uint32_t index;
} TSQueryCapture;
/* The kind of problem reported by `ts_query_new` via `error_type`. */
typedef enum {
/* No error. */
TSQueryErrorNone = 0,
/* The query source is syntactically malformed. */
TSQueryErrorSyntax,
/* The query names a node type that the language does not define. */
TSQueryErrorNodeType,
/* The query names a field that the language does not define. */
TSQueryErrorField,
} TSQueryError;
/********************/
/* Section - Parser */
/********************/
@ -602,6 +616,107 @@ int64_t ts_tree_cursor_goto_first_child_for_byte(TSTreeCursor *, uint32_t);
TSTreeCursor ts_tree_cursor_copy(const TSTreeCursor *);
/*******************/
/* Section - Query */
/*******************/
/**
 * Create a new query from a string containing one or more S-expression
 * patterns. The query is associated with a particular language, and can
 * only be run on syntax nodes parsed with that language.
 *
 * If all of the given patterns are valid, this returns a `TSQuery`.
 * If a pattern is invalid, this returns `NULL`, and provides two pieces
 * of information about the problem:
 * 1. The byte offset of the error is written to the `error_offset` parameter.
 * 2. The type of error is written to the `error_type` parameter.
 */
TSQuery *ts_query_new(
const TSLanguage *language,
const char *source,
uint32_t source_len,
uint32_t *error_offset,
TSQueryError *error_type
);
/**
 * Delete a query, freeing all of the memory that it used.
 */
void ts_query_delete(TSQuery *);
/*
 * Get the number of distinct capture names in the query.
 */
uint32_t ts_query_capture_count(const TSQuery *);
/*
 * Get the name and length of one of the query's captures. Each capture
 * is associated with a numeric id based on the order in which it first
 * appeared in the query's source. The name's length in bytes is written
 * to the `length` parameter.
 */
const char *ts_query_capture_name_for_id(
const TSQuery *self,
uint32_t index,
uint32_t *length
);
/*
 * Get the numeric id of the capture with the given name, or -1 if the
 * query has no capture with that name.
 */
int ts_query_capture_id_for_name(
const TSQuery *self,
const char *name,
uint32_t length
);
/*
 * Create a new context for executing a given query.
 *
 * The context stores the state that is needed to iteratively search
 * for matches. To use the query context:
 * 1. First call `ts_query_context_exec` to start running the query
 *    on a particular syntax node.
 * 2. Then repeatedly call `ts_query_context_next` to iterate over
 *    the matches.
 * 3. For each match, you can call `ts_query_context_matched_pattern_index`
 *    to determine which pattern matched. You can also call
 *    `ts_query_context_matched_captures` to determine which nodes
 *    were captured by which capture names.
 *
 * If you don't care about finding all of the matches, you can stop calling
 * `ts_query_context_next` at any point. And you can start executing the
 * query against a different node by calling `ts_query_context_exec` again.
 */
TSQueryContext *ts_query_context_new(const TSQuery *);
/*
 * Delete a query context, freeing all of the memory that it used.
 */
void ts_query_context_delete(TSQueryContext *);
/*
 * Start running a query on a given node.
 */
void ts_query_context_exec(TSQueryContext *, TSNode);
/*
 * Advance to the next match of the currently running query. Returns
 * `false` when there are no more matches.
 */
bool ts_query_context_next(TSQueryContext *);
/*
 * Check which pattern matched.
 */
uint32_t ts_query_context_matched_pattern_index(const TSQueryContext *);
/*
 * Get the nodes captured by the current match. The number of captures
 * is written through the `uint32_t *` parameter; each returned entry
 * pairs a node with the numeric id of its capture name.
 */
const TSQueryCapture *ts_query_context_matched_captures(
const TSQueryContext *,
uint32_t *
);
/**********************/
/* Section - Language */
/**********************/

25
lib/src/bits.h Normal file
View file

@ -0,0 +1,25 @@
#ifndef TREE_SITTER_BITS_H_
#define TREE_SITTER_BITS_H_

#include <stdint.h>

/*
 * count_leading_zeros - Return the number of zero bits above the highest
 * set bit of `x`. Returns 32 when `x` is zero, a case the underlying
 * compiler intrinsics leave undefined.
 */

#ifdef _WIN32

#include <intrin.h>

static inline uint32_t count_leading_zeros(uint32_t x) {
  if (x == 0) return 32;
  // _BitScanReverse reports the *index* of the highest set bit (0 = least
  // significant) through an `unsigned long *`, so convert it to a count of
  // leading zeros.
  unsigned long index;
  _BitScanReverse(&index, x);
  return 31 - (uint32_t)index;
}

#else

static inline uint32_t count_leading_zeros(uint32_t x) {
  if (x == 0) return 32;
  return __builtin_clz(x);
}

#endif

#endif  // TREE_SITTER_BITS_H_

View file

@ -12,6 +12,7 @@
#include "./lexer.c"
#include "./node.c"
#include "./parser.c"
#include "./query.c"
#include "./stack.c"
#include "./subtree.c"
#include "./tree_cursor.c"

810
lib/src/query.c Normal file
View file

@ -0,0 +1,810 @@
#include "tree_sitter/api.h"
#include "./alloc.h"
#include "./array.h"
#include "./bits.h"
#include "utf8proc.h"
#include <wctype.h>
/*
 * Stream - A sequence of unicode characters derived from a UTF8 string.
 * This struct is used in parsing query S-expressions.
 */
typedef struct {
// Current read position within the source string.
const char *input;
// One past the last byte of the source string.
const char *end;
// The code point decoded at `input`.
int32_t next;
// Size in bytes of the encoded form of `next`; the amount by which `input`
// moves on the next advance.
uint8_t next_size;
} Stream;
/*
 * QueryStep - A step in the process of matching a query. Each node within
 * a query S-expression maps to one of these steps. An entire pattern is
 * represented as a sequence of these steps.
 */
typedef struct {
// The node type to match, or WILDCARD_SYMBOL to match any node.
TSSymbol symbol;
// The field that the node must occupy in its parent, or 0 for no constraint.
TSFieldId field;
// The id of the capture name attached to this step, or NONE.
uint16_t capture_id;
// Nesting depth of this step within its pattern (0 for the root). The
// value PATTERN_DONE_MARKER marks the sentinel step that ends a pattern.
uint8_t depth;
// NOTE(review): not assigned anywhere in the visible part of this file.
bool field_is_multiple;
} QueryStep;
/*
 * CaptureSlice - The name of a capture, represented as a slice of a
 * shared string.
 */
typedef struct {
// Byte offset of the name within `TSQuery.capture_data`.
uint32_t offset;
// Length of the name in bytes, not counting its trailing NUL.
uint32_t length;
} CaptureSlice;
/*
 * PatternSlice - The set of steps needed to match a particular pattern,
 * represented as a slice of a shared array.
 */
typedef struct {
// Index of the pattern's first step within `TSQuery.steps`.
uint16_t step_index;
// Position of the pattern among the patterns of the query, in the order
// in which they were parsed.
uint16_t pattern_index;
} PatternSlice;
/*
 * QueryState - The state of an in-progress match of a particular pattern
 * in a query. While executing, a QueryContext must keep track of a number
 * of possible in-progress matches. Each of those possible matches is
 * represented as one of these states.
 */
typedef struct {
// Index within `TSQuery.steps` of the next step that must be matched.
uint16_t step_index;
// Which of the query's patterns this state is matching.
uint16_t pattern_index;
// Depth of the tree cursor at the node where this match started.
uint16_t start_depth;
// Id of the capture list, acquired from the CaptureListPool, that holds
// this state's captures.
uint16_t capture_list_id;
// Number of captures recorded so far in that list.
uint16_t capture_count;
} QueryState;
/*
 * CaptureListPool - A collection of *lists* of captures. Each QueryState
 * needs to maintain its own list of captures. They are all represented as
 * slices of one shared array. The CaptureListPool keeps track of which
 * parts of the shared array are currently in use by a QueryState.
 */
typedef struct {
// Backing storage: MAX_STATE_COUNT consecutive lists of `list_size` entries.
TSQueryCapture *contents;
// Number of entries in each list.
uint32_t list_size;
// One bit per list; a set bit means that the list is free.
uint32_t usage_map;
} CaptureListPool;
/*
 * TSQuery - A tree query, compiled from a string of S-expressions. The query
 * itself is immutable. The mutable state used in the process of executing the
 * query is stored in a `TSQueryContext`.
 */
struct TSQuery {
// The steps of every pattern, concatenated; each pattern is terminated by a
// step whose depth is PATTERN_DONE_MARKER.
Array(QueryStep) steps;
// The NUL-terminated capture names, stored back to back.
Array(char) capture_data;
// One slice of `capture_data` per capture name, indexed by capture id.
Array(CaptureSlice) capture_names;
// Lookup table from a pattern's root symbol to its steps. Patterns with a
// wildcard root come first; the remainder are ordered by root symbol.
Array(PatternSlice) pattern_map;
const TSLanguage *language;
// The largest number of captures in any single pattern.
uint16_t max_capture_count;
// Number of patterns whose root step is a wildcard.
uint16_t wildcard_root_pattern_count;
};
/*
 * TSQueryContext - A stateful struct used to execute a query on a tree.
 */
struct TSQueryContext {
// The query being executed.
const TSQuery *query;
// Cursor used to walk the tree being searched.
TSTreeCursor cursor;
// Matches that are still in progress.
Array(QueryState) states;
// Completed matches that have not yet been consumed by the caller.
Array(QueryState) finished_states;
// Storage for the capture lists of all states.
CaptureListPool capture_list_pool;
// True while the cursor is moving back up (or sideways) out of a node,
// false while it is descending into one.
bool ascending;
// Depth of the cursor relative to the node passed to ts_query_context_exec.
uint32_t depth;
};
// Internal result of ts_query_parse_pattern: the stream is positioned at the
// ')' that closes the enclosing S-expression.
static const TSQueryError PARENT_DONE = -1;
// Depth value of the sentinel step appended after each pattern's steps.
static const uint8_t PATTERN_DONE_MARKER = UINT8_MAX;
// "No value" sentinel for 16-bit ids (capture ids, capture list ids).
static const uint16_t NONE = UINT16_MAX;
// Symbol value of a step that matches any node.
static const TSSymbol WILDCARD_SYMBOL = 0;
// Number of capture lists in a CaptureListPool; matches the number of bits
// in its 32-bit usage map.
static const uint16_t MAX_STATE_COUNT = 32;
/**********
* Stream
**********/
// Step past the current character and decode the one that follows it.
// Returns false, leaving `next_size` unchanged, if the stream was already at
// its end or if no further character could be decoded.
static bool stream_advance(Stream *self) {
  bool advanced = false;
  if (self->input < self->end) {
    self->input += self->next_size;
    const uint8_t *position = (const uint8_t *)self->input;
    int decoded_size = utf8proc_iterate(
      position,
      self->end - self->input,
      &self->next
    );
    if (decoded_size > 0) {
      self->next_size = decoded_size;
      advanced = true;
    }
  }
  return advanced;
}
// Reposition the stream at `input` and decode the character found there.
static void stream_reset(Stream *self, const char *input) {
  // A zero size makes the advance below decode at `input` itself instead of
  // skipping ahead.
  self->next_size = 0;
  self->input = input;
  stream_advance(self);
}
static Stream stream_new(const char *string, uint32_t length) {
Stream self = {
.next = 0,
.input = string,
.end = string + length,
};
stream_advance(&self);
return self;
}
// Advance the stream until its current character is not whitespace.
static void stream_skip_whitespace(Stream *stream) {
  for (;;) {
    if (!iswspace(stream->next)) break;
    stream_advance(stream);
  }
}
// Whether the current character can begin an identifier: a letter, an
// underscore or a dash.
static bool stream_is_ident_start(Stream *stream) {
  if (stream->next == '_' || stream->next == '-') return true;
  return iswalpha(stream->next) != 0;
}
// Consume an identifier. The current character is taken unconditionally (the
// caller has already checked it); after that, letters, digits, underscores,
// dashes and dots are consumed.
static void stream_scan_identifier(Stream *stream) {
  for (;;) {
    stream_advance(stream);
    bool is_identifier_char =
      iswalnum(stream->next) ||
      stream->next == '_' ||
      stream->next == '-' ||
      stream->next == '.';
    if (!is_identifier_char) break;
  }
}
/******************
* CaptureListPool
******************/
static CaptureListPool capture_list_pool_new(uint16_t list_size) {
return (CaptureListPool) {
.contents = ts_calloc(MAX_STATE_COUNT * list_size, sizeof(TSQueryCapture)),
.list_size = list_size,
.usage_map = UINT32_MAX,
};
}
// Mark every list in the pool as free again.
static void capture_list_pool_clear(CaptureListPool *self) {
  const uint32_t all_lists_free = UINT32_MAX;
  self->usage_map = all_lists_free;
}
// Free the pool's backing storage. The pool struct itself is owned by the
// caller.
static void capture_list_pool_delete(CaptureListPool *self) {
  TSQueryCapture *storage = self->contents;
  ts_free(storage);
}
// Return a pointer to the first entry of the list with the given id.
static TSQueryCapture *capture_list_pool_get(CaptureListPool *self, uint16_t id) {
  return self->contents + (id * self->list_size);
}
// Map a list id to its bit in `usage_map`. Ids are produced by counting
// leading zeros, so id 0 corresponds to the *highest*-order bit. The shift
// is done on an unsigned 32-bit value so that every id in 0..31 is well
// defined.
static inline uint32_t capture_list_pool__bit_for_id(uint16_t id) {
  return ((uint32_t)1) << (31 - id);
}

// Reserve a free capture list and return its id, or NONE if all
// MAX_STATE_COUNT lists are in use.
static uint16_t capture_list_pool_acquire(CaptureListPool *self) {
  // A set bit means "free", so the first set bit, counting from the top,
  // identifies the list to hand out.
  uint16_t id = count_leading_zeros(self->usage_map);
  if (id == 32) return NONE;
  self->usage_map &= ~capture_list_pool__bit_for_id(id);
  return id;
}

// Return a previously acquired capture list to the pool.
static void capture_list_pool_release(CaptureListPool *self, uint16_t id) {
  self->usage_map |= capture_list_pool__bit_for_id(id);
}
/*********
* Query
*********/
// Look up the symbol of the given type whose name is exactly the `length`
// bytes starting at `name`. Returns 0 if the language has no such symbol.
static TSSymbol ts_query_intern_node_name(
  const TSQuery *self,
  const char *name,
  uint32_t length,
  TSSymbolType symbol_type
) {
  uint32_t symbol_count = ts_language_symbol_count(self->language);
  for (TSSymbol i = 0; i < symbol_count; i++) {
    if (ts_language_symbol_type(self->language, i) != symbol_type) continue;
    const char *symbol_name = ts_language_symbol_name(self->language, i);
    // `name` is not NUL-terminated, so compare `length` bytes and then make
    // sure that the symbol's name ends there too. Without the second check,
    // "if" would match "if_statement".
    if (!strncmp(symbol_name, name, length) && !symbol_name[length]) return i;
  }
  return 0;
}
// Return the id of the capture with the given name, registering the name
// first if the query has not seen it before.
static uint16_t ts_query_intern_capture_name(
  TSQuery *self,
  const char *name,
  uint32_t length
) {
  int existing_id = ts_query_capture_id_for_name(self, name, length);
  if (existing_id < 0) {
    // Append a NUL-terminated copy of the name to the shared buffer and
    // record where it lives.
    uint32_t offset = self->capture_data.size;
    array_grow_by(&self->capture_data, length + 1);
    memcpy(&self->capture_data.contents[offset], name, length);
    self->capture_data.contents[self->capture_data.size - 1] = 0;
    array_push(&self->capture_names, ((CaptureSlice) {
      .offset = offset,
      .length = length,
    }));
    return self->capture_names.size - 1;
  }
  return (uint16_t)existing_id;
}
// Binary-search the symbol-ordered part of the pattern map (everything after
// the wildcard-rooted patterns) for patterns whose root symbol is `needle`.
//
// `*result` receives the index of the first entry whose root symbol is not
// less than `needle`. That is the first matching entry when there is one,
// and the position at which a new entry must be inserted to keep the map
// sorted otherwise. Returns true if an entry with root symbol `needle` exists.
static inline bool ts_query__pattern_map_search(
  const TSQuery *self,
  TSSymbol needle,
  uint32_t *result
) {
  uint32_t base_index = self->wildcard_root_pattern_count;
  uint32_t size = self->pattern_map.size - base_index;
  if (size == 0) {
    *result = base_index;
    return false;
  }

  // Narrow down to the last entry whose symbol is less than the needle (or
  // to the first entry, if no symbol is less than the needle).
  while (size > 1) {
    uint32_t half_size = size / 2;
    uint32_t mid_index = base_index + half_size;
    TSSymbol mid_symbol = self->steps.contents[
      self->pattern_map.contents[mid_index].step_index
    ].symbol;
    if (needle > mid_symbol) base_index = mid_index;
    size -= half_size;
  }

  TSSymbol symbol = self->steps.contents[
    self->pattern_map.contents[base_index].step_index
  ].symbol;

  // If that entry is still less than the needle, the answer is the entry
  // right after it, which may be one past the end of the map.
  if (needle > symbol) {
    base_index++;
    if (base_index < self->pattern_map.size) {
      symbol = self->steps.contents[
        self->pattern_map.contents[base_index].step_index
      ].symbol;
    }
  }

  *result = base_index;
  return needle == symbol;
}
// Register a pattern, whose root step has the given symbol and whose steps
// begin at `start_step_index`, at the position chosen by
// ts_query__pattern_map_search.
static inline void ts_query__pattern_map_insert(
  TSQuery *self,
  TSSymbol symbol,
  uint32_t start_step_index
) {
  uint32_t insertion_index;
  ts_query__pattern_map_search(self, symbol, &insertion_index);

  // Patterns are numbered in the order in which they are inserted.
  PatternSlice slice = {
    .step_index = start_step_index,
    .pattern_index = self->pattern_map.size,
  };
  array_insert(&self->pattern_map, insertion_index, slice);
}
// Parse one pattern from the stream, appending its steps to `self->steps`.
//
// `depth` is the nesting depth of the pattern within the top-level pattern
// being parsed. `*capture_count` is incremented for each capture parsed.
//
// Returns 0 on success, PARENT_DONE if the stream is positioned at the ')'
// that closes the enclosing S-expression (nothing is consumed), or a
// TSQueryError. On error, the stream's position is the offset to report.
static TSQueryError ts_query_parse_pattern(
  TSQuery *self,
  Stream *stream,
  uint32_t depth,
  uint32_t *capture_count
) {
  uint16_t starting_step_index = self->steps.size;

  if (stream->next == 0) return TSQueryErrorSyntax;

  // Finish the parent S-expression
  if (stream->next == ')') {
    return PARENT_DONE;
  }

  // Parse a parenthesized node expression
  else if (stream->next == '(') {
    stream_advance(stream);
    stream_skip_whitespace(stream);
    TSSymbol symbol;

    // Parse the wildcard symbol
    if (stream->next == '*') {
      symbol = WILDCARD_SYMBOL;
      stream_advance(stream);
    }

    // Parse a normal node name
    else if (stream_is_ident_start(stream)) {
      const char *node_name = stream->input;
      stream_scan_identifier(stream);
      uint32_t length = stream->input - node_name;
      symbol = ts_query_intern_node_name(
        self,
        node_name,
        length,
        TSSymbolTypeRegular
      );
      if (!symbol) {
        // Report the error at the start of the unknown name.
        stream_reset(stream, node_name);
        return TSQueryErrorNodeType;
      }
    } else {
      return TSQueryErrorSyntax;
    }

    // Add a step for the node.
    array_push(&self->steps, ((QueryStep) {
      .depth = depth,
      .symbol = symbol,
      .field = 0,
      .capture_id = NONE,
    }));

    // Parse the child patterns
    stream_skip_whitespace(stream);
    for (;;) {
      TSQueryError e = ts_query_parse_pattern(self, stream, depth + 1, capture_count);
      if (e == PARENT_DONE) {
        // Consume the closing paren.
        stream_advance(stream);
        break;
      } else if (e) {
        return e;
      }
    }
  }

  // Parse a double-quoted anonymous leaf node expression
  else if (stream->next == '"') {
    stream_advance(stream);

    // Parse the string content
    const char *string_content = stream->input;
    while (stream->next && stream->next != '"') stream_advance(stream);
    uint32_t length = stream->input - string_content;

    // Add a step for the node
    TSSymbol symbol = ts_query_intern_node_name(
      self,
      string_content,
      length,
      TSSymbolTypeAnonymous
    );
    if (!symbol) {
      stream_reset(stream, string_content);
      return TSQueryErrorNodeType;
    }
    // The capture id must be NONE explicitly: a zero-initialized id would be
    // indistinguishable from the query's first capture.
    array_push(&self->steps, ((QueryStep) {
      .depth = depth,
      .symbol = symbol,
      .field = 0,
      .capture_id = NONE,
    }));

    if (stream->next != '"') return TSQueryErrorSyntax;
    stream_advance(stream);
  }

  // Parse a field-prefixed pattern
  else if (stream_is_ident_start(stream)) {
    // Parse the field name
    const char *field_name = stream->input;
    stream_scan_identifier(stream);
    uint32_t length = stream->input - field_name;
    stream_skip_whitespace(stream);

    // A bare identifier that is not followed by a colon is a syntax error,
    // reported at the start of the identifier.
    if (stream->next != ':') {
      stream_reset(stream, field_name);
      return TSQueryErrorSyntax;
    }
    stream_advance(stream);
    stream_skip_whitespace(stream);

    // Parse the pattern
    uint32_t step_index = self->steps.size;
    TSQueryError e = ts_query_parse_pattern(self, stream, depth, capture_count);
    if (e == PARENT_DONE) return TSQueryErrorSyntax;
    if (e) return e;

    // Add the field name to the first step of the pattern
    TSFieldId field_id = ts_language_field_id_for_name(
      self->language,
      field_name,
      length
    );
    if (!field_id) {
      // Report the error at the start of the unknown field name, in the same
      // way as the other error paths.
      stream_reset(stream, field_name);
      return TSQueryErrorField;
    }
    self->steps.contents[step_index].field = field_id;
  }

  // Parse a wildcard pattern
  else if (stream->next == '*') {
    stream_advance(stream);
    stream_skip_whitespace(stream);

    // Add a step that matches any kind of node. As above, the capture id
    // must be NONE rather than an implicit zero.
    array_push(&self->steps, ((QueryStep) {
      .depth = depth,
      .symbol = WILDCARD_SYMBOL,
      .field = 0,
      .capture_id = NONE,
    }));
  }

  // No match
  else {
    return TSQueryErrorSyntax;
  }

  stream_skip_whitespace(stream);

  // Parse a '@'-suffixed capture pattern
  if (stream->next == '@') {
    stream_advance(stream);
    stream_skip_whitespace(stream);

    // Parse the capture name
    if (!stream_is_ident_start(stream)) return TSQueryErrorSyntax;
    const char *capture_name = stream->input;
    stream_scan_identifier(stream);
    uint32_t length = stream->input - capture_name;

    // Add the capture id to the first step of the pattern
    uint16_t capture_id = ts_query_intern_capture_name(
      self,
      capture_name,
      length
    );
    self->steps.contents[starting_step_index].capture_id = capture_id;
    (*capture_count)++;

    stream_skip_whitespace(stream);
  }

  return 0;
}
// Compile the patterns in `source` into a new query for `language`.
//
// Returns NULL if any pattern is invalid, in which case `*error_type` and
// `*error_offset` describe the problem. On success, `*error_type` is zero.
TSQuery *ts_query_new(
  const TSLanguage *language,
  const char *source,
  uint32_t source_len,
  uint32_t *error_offset,
  TSQueryError *error_type
) {
  TSQuery *self = ts_malloc(sizeof(TSQuery));
  *self = (TSQuery) {
    .language = language,
    .steps = array_new(),
    .pattern_map = array_new(),
    .max_capture_count = 0,
    .wildcard_root_pattern_count = 0,
  };

  // Parse all of the S-expressions in the given string, one top-level
  // pattern per iteration, until the input is exhausted.
  Stream stream = stream_new(source, source_len);
  stream_skip_whitespace(&stream);
  do {
    uint32_t first_step_index = self->steps.size;
    uint32_t pattern_capture_count = 0;
    *error_type = ts_query_parse_pattern(self, &stream, 0, &pattern_capture_count);

    // Terminate the pattern's steps with a sentinel.
    array_push(&self->steps, ((QueryStep) { .depth = PATTERN_DONE_MARKER }));

    // If any pattern could not be parsed, then report the error information
    // and terminate.
    if (*error_type) {
      *error_offset = stream.input - source;
      ts_query_delete(self);
      return NULL;
    }

    // Maintain a map that can look up patterns for a given root symbol.
    TSSymbol root_symbol = self->steps.contents[first_step_index].symbol;
    ts_query__pattern_map_insert(self, root_symbol, first_step_index);
    if (root_symbol == WILDCARD_SYMBOL) {
      self->wildcard_root_pattern_count++;
    }

    // Remember the largest capture count, which determines the size of the
    // capture lists needed to execute the query.
    if (pattern_capture_count > self->max_capture_count) {
      self->max_capture_count = pattern_capture_count;
    }
  } while (stream.input != stream.end);

  return self;
}
// Free a query and everything that it owns. Passing NULL is a no-op.
void ts_query_delete(TSQuery *self) {
  if (!self) return;
  array_delete(&self->steps);
  array_delete(&self->pattern_map);
  array_delete(&self->capture_data);
  array_delete(&self->capture_names);
  ts_free(self);
}
// Number of distinct capture names in the query.
uint32_t ts_query_capture_count(const TSQuery *self) {
  uint32_t count = self->capture_names.size;
  return count;
}
// Return the name of the capture with the given id and write its length in
// bytes to `*length`. The id is not bounds-checked.
const char *ts_query_capture_name_for_id(
  const TSQuery *self,
  uint32_t index,
  uint32_t *length
) {
  const CaptureSlice *slice = &self->capture_names.contents[index];
  *length = slice->length;
  return self->capture_data.contents + slice->offset;
}
// Return the id of the capture whose name is exactly the `length` bytes
// starting at `name`, or -1 if the query has no such capture.
int ts_query_capture_id_for_name(
  const TSQuery *self,
  const char *name,
  uint32_t length
) {
  for (unsigned id = 0; id < self->capture_names.size; id++) {
    const CaptureSlice *candidate = &self->capture_names.contents[id];
    if (candidate->length != length) continue;
    const char *candidate_name = self->capture_data.contents + candidate->offset;
    if (strncmp(candidate_name, name, length) == 0) return id;
  }
  return -1;
}
/***************
* QueryContext
***************/
TSQueryContext *ts_query_context_new(const TSQuery *query) {
TSQueryContext *self = ts_malloc(sizeof(TSQueryContext));
*self = (TSQueryContext) {
.query = query,
.ascending = false,
.states = array_new(),
.finished_states = array_new(),
.capture_list_pool = capture_list_pool_new(query->max_capture_count),
};
return self;
}
// Free a query context and everything that it owns. The query itself is not
// owned by the context and is left untouched.
void ts_query_context_delete(TSQueryContext *self) {
  capture_list_pool_delete(&self->capture_list_pool);
  ts_tree_cursor_delete(&self->cursor);
  array_delete(&self->finished_states);
  array_delete(&self->states);
  ts_free(self);
}
// Begin executing the context's query on `node`, discarding any state left
// over from a previous execution.
void ts_query_context_exec(TSQueryContext *self, TSNode node) {
  // Drop all in-progress and finished matches, and free their capture lists.
  array_clear(&self->states);
  array_clear(&self->finished_states);
  capture_list_pool_clear(&self->capture_list_pool);

  // Restart the traversal at the given node.
  ts_tree_cursor_reset(&self->cursor, node);
  self->ascending = false;
  self->depth = 0;
}
bool ts_query_context_next(TSQueryContext *self) {
if (self->finished_states.size > 0) {
array_pop(&self->finished_states);
}
while (self->finished_states.size == 0) {
if (self->ascending) {
// Remove any states that were started within this node and are still
// not complete.
uint32_t deleted_count = 0;
for (unsigned i = 0, n = self->states.size; i < n; i++) {
QueryState *state = &self->states.contents[i];
if (state->start_depth == self->depth) {
// printf("FAIL STATE pattern: %u, step: %u\n", state->pattern_index, state->step_index);
capture_list_pool_release(
&self->capture_list_pool,
state->capture_list_id
);
deleted_count++;
} else if (deleted_count > 0) {
self->states.contents[i - deleted_count] = *state;
}
}
// if (deleted_count) {
// printf("FAILED %u of %u states\n", deleted_count, self->states.size);
// }
self->states.size -= deleted_count;
if (ts_tree_cursor_goto_next_sibling(&self->cursor)) {
self->ascending = false;
} else if (ts_tree_cursor_goto_parent(&self->cursor)) {
self->depth--;
} else {
return false;
}
} else {
TSFieldId field_id = NONE;
bool field_occurs_in_later_sibling = false;
TSNode node = ts_tree_cursor_current_node(&self->cursor);
TSSymbol symbol = ts_node_symbol(node);
// printf("DESCEND INTO NODE: %s\n", ts_node_type(node));
// Add new states for any patterns whose root node is a wildcard.
for (unsigned i = 0; i < self->query->wildcard_root_pattern_count; i++) {
PatternSlice *slice = &self->query->pattern_map.contents[i];
QueryStep *step = &self->query->steps.contents[slice->step_index];
// Check that the node matches the criteria for the first step
// of the pattern.
if (step->field) {
if (field_id == NONE) {
field_id = ts_tree_cursor_current_field_id_ext(
&self->cursor,
&field_occurs_in_later_sibling
);
}
if (field_id != step->field) continue;
}
// Add a new state at the start of this pattern.
uint32_t capture_list_id = capture_list_pool_acquire(
&self->capture_list_pool
);
if (capture_list_id == NONE) break;
array_push(&self->states, ((QueryState) {
.step_index = slice->step_index,
.pattern_index = slice->pattern_index,
.capture_list_id = capture_list_id,
}));
}
// Add new states for any patterns whose root node matches this node.
unsigned i;
if (ts_query__pattern_map_search(self->query, symbol, &i)) {
PatternSlice *slice = &self->query->pattern_map.contents[i];
QueryStep *step = &self->query->steps.contents[slice->step_index];
do {
if (step->field) {
if (field_id == NONE) {
field_id = ts_tree_cursor_current_field_id_ext(
&self->cursor,
&field_occurs_in_later_sibling
);
}
if (field_id != step->field) continue;
}
// printf("START NEW STATE: %u\n", slice->pattern_index);
// If the node matches the first step of the pattern, then add
// a new in-progress state. First, acquire a list to hold the
// pattern's captures.
uint32_t capture_list_id = capture_list_pool_acquire(
&self->capture_list_pool
);
if (capture_list_id == NONE) break;
array_push(&self->states, ((QueryState) {
.pattern_index = slice->pattern_index,
.step_index = slice->step_index + 1,
.start_depth = self->depth,
.capture_list_id = capture_list_id,
.capture_count = 0,
}));
i++;
if (i == self->query->pattern_map.size) break;
slice = &self->query->pattern_map.contents[i];
step = &self->query->steps.contents[slice->step_index];
} while (step->symbol == symbol);
}
// Update all of the in-progress states with current node.
for (unsigned i = 0, n = self->states.size; i < n; i++) {
QueryState *state = &self->states.contents[i];
QueryStep *step = &self->query->steps.contents[state->step_index];
// Check that the node matches all of the criteria for the next
// step of the pattern.
if (state->start_depth + step->depth != self->depth) continue;
if (step->symbol && step->symbol != symbol) continue;
if (step->field) {
// Only compute the current field if it is needed for the current
// step of some in-progress pattern.
if (field_id == NONE) {
field_id = ts_tree_cursor_current_field_id_ext(
&self->cursor,
&field_occurs_in_later_sibling
);
}
if (field_id != step->field) continue;
}
// Some patterns can match their root node in multiple ways,
// capturing different children. If this pattern step could match
// later children within the same parent, then this query state
// cannot simply be updated in place. It must be split into two
// states: one that captures this node, and one which skips over
// this node, to preserve the possibility of capturing later
// siblings.
QueryState *next_state = state;
if (step->depth > 0 && (!step->field || field_occurs_in_later_sibling)) {
uint32_t capture_list_id = capture_list_pool_acquire(
&self->capture_list_pool
);
if (capture_list_id != NONE) {
array_push(&self->states, *state);
next_state = array_back(&self->states);
next_state->capture_list_id = capture_list_id;
}
}
// Record captures
if (step->capture_id != NONE) {
// printf("CAPTURE id: %u\n", step->capture_id);
TSQueryCapture *capture_list = capture_list_pool_get(
&self->capture_list_pool,
next_state->capture_list_id
);
capture_list[next_state->capture_count++] = (TSQueryCapture) {
node,
step->capture_id
};
}
// If the pattern is now done, then populate the query context's
// finished state.
next_state->step_index++;
QueryStep *next_step = step + 1;
if (next_step->depth == PATTERN_DONE_MARKER) {
// printf("FINISHED MATCH pattern: %u\n", next_state->pattern_index);
array_push(&self->finished_states, *next_state);
if (next_state == state) {
array_erase(&self->states, i);
i--;
n--;
} else {
array_pop(&self->states);
}
}
}
if (ts_tree_cursor_goto_first_child(&self->cursor)) {
self->depth++;
} else {
self->ascending = true;
}
}
}
return true;
}
// Report which pattern produced the most recently finished match.
//
// Reads the last entry of `finished_states`. When no match has finished
// yet, 0 is returned.
uint32_t ts_query_context_matched_pattern_index(const TSQueryContext *self) {
  if (self->finished_states.size == 0) return 0;
  QueryState *latest = array_back(&self->finished_states);
  return latest->pattern_index;
}
// Return the captures recorded for the most recently finished match.
//
// When a finished match exists, `*count` receives its capture count and the
// returned pointer is the capture list fetched from the context's capture
// list pool using the match's `capture_list_id`.
//
// When no match has finished, NULL is returned and `*count` is set to 0, so
// a caller that inspects the count before the pointer never reads an
// indeterminate value.
const TSQueryCapture *ts_query_context_matched_captures(
  const TSQueryContext *self,
  uint32_t *count
) {
  if (self->finished_states.size == 0) {
    *count = 0;
    return NULL;
  }

  QueryState *state = array_back(&self->finished_states);
  *count = state->capture_count;

  // The cast drops `const` from the pool pointer; presumably the accessor
  // takes a non-const pool (NOTE(review): confirm against its declaration).
  return capture_list_pool_get(
    (CaptureListPool *)&self->capture_list_pool,
    state->capture_list_id
  );
}

View file

@ -244,7 +244,12 @@ TSNode ts_tree_cursor_current_node(const TSTreeCursor *_self) {
);
}
TSFieldId ts_tree_cursor_current_field_id(const TSTreeCursor *_self) {
static inline TSFieldId ts_tree_cursor__current_field_info(
const TSTreeCursor *_self,
const TSFieldMapEntry **field_map,
const TSFieldMapEntry **field_map_end,
uint32_t *child_index
) {
const TreeCursor *self = (const TreeCursor *)_self;
// Walk up the tree, visiting the current node and its invisible ancestors.
@ -264,25 +269,61 @@ TSFieldId ts_tree_cursor_current_field_id(const TSTreeCursor *_self) {
}
}
const TSFieldMapEntry *field_map, *field_map_end;
if (ts_subtree_extra(*entry->subtree)) break;
ts_language_field_map(
self->tree->language,
parent_entry->subtree->ptr->production_id,
&field_map, &field_map_end
field_map, field_map_end
);
while (field_map < field_map_end) {
if (
!ts_subtree_extra(*entry->subtree) &&
!field_map->inherited &&
field_map->child_index == entry->structural_child_index
) return field_map->field_id;
field_map++;
for (const TSFieldMapEntry *i = *field_map; i < *field_map_end; i++) {
if (!i->inherited && i->child_index == entry->structural_child_index) {
*child_index = entry->structural_child_index;
return i->field_id;
}
}
}
return 0;
}
// Look up the field id of the cursor's current node and report whether a
// later sibling carries the same field.
//
// Returns the field id produced by `ts_tree_cursor__current_field_info`
// (0 when the node has no field).
//
// `*field_has_additional` is written on every call: true when some entry in
// the same field map has this field id and a greater child index than the
// current node, false otherwise.
TSFieldId ts_tree_cursor_current_field_id_ext(
  const TSTreeCursor *self,
  bool *field_has_additional
) {
  // Give the out-parameter a defined value on every path, so callers do not
  // have to pre-initialize it.
  *field_has_additional = false;

  uint32_t child_index;
  const TSFieldMapEntry *field_map, *field_map_end;
  TSFieldId field_id = ts_tree_cursor__current_field_info(
    self,
    &field_map,
    &field_map_end,
    &child_index
  );

  // The field map range and child index are only meaningful when a field
  // was found, so they are not touched otherwise.
  if (field_id) {
    for (const TSFieldMapEntry *i = field_map; i < field_map_end; i++) {
      if (i->field_id == field_id && i->child_index > child_index) {
        *field_has_additional = true;
        // One later occurrence is enough to answer the question.
        break;
      }
    }
  }

  return field_id;
}
// Return the field id of the cursor's current node, or 0 if it has none.
//
// Thin wrapper over `ts_tree_cursor__current_field_info`: the field-map
// range and child index that helper reports are discarded here.
TSFieldId ts_tree_cursor_current_field_id(const TSTreeCursor *self) {
  const TSFieldMapEntry *entries_begin;
  const TSFieldMapEntry *entries_end;
  uint32_t structural_index;
  TSFieldId result = ts_tree_cursor__current_field_info(
    self, &entries_begin, &entries_end, &structural_index
  );
  return result;
}
const char *ts_tree_cursor_current_field_name(const TSTreeCursor *_self) {
TSFieldId id = ts_tree_cursor_current_field_id(_self);
if (id) {

View file

@ -16,5 +16,6 @@ typedef struct {
} TreeCursor;
// Set up a `TreeCursor` from a node.
// NOTE(review): the implementation is not visible in this chunk — confirm
// the exact initial state against tree_cursor.c.
void ts_tree_cursor_init(TreeCursor *, TSNode);
// Returns the field id of the cursor's current node (0 when it has none).
// The bool out-parameter is set to true when a later child in the same
// field map carries the same field id. Callers should initialize it to
// false before calling.
TSFieldId ts_tree_cursor_current_field_id_ext(const TSTreeCursor *, bool *);
#endif // TREE_SITTER_TREE_CURSOR_H_