Allow predicates in queries, to match on nodes' text

This commit is contained in:
Max Brunsfeld 2019-09-15 22:06:51 -07:00
parent 307a1a6c11
commit 096126d039
8 changed files with 781 additions and 186 deletions

View file

@ -30,13 +30,17 @@ typedef struct {
} QueryStep;
/*
* CaptureSlice - The name of a capture, represented as a slice of a
* shared string.
* Slice - A string represented as a slice of a shared string.
*/
typedef struct {
uint32_t offset;
uint32_t length;
} CaptureSlice;
} Slice;
/*
 * SymbolTable - A two-way mapping between strings and integer ids. The bytes
 * of every string live back-to-back in `characters`; `slices` holds one
 * offset/length entry per string, and a string's id is its index in `slices`.
 */
typedef struct {
Array(char) characters;
Array(Slice) slices;
} SymbolTable;
/*
* PatternSlice - The set of steps needed to match a particular pattern,
@ -60,6 +64,7 @@ typedef struct {
uint8_t capture_count;
uint8_t capture_list_id;
uint8_t consumed_capture_count;
uint32_t id;
} QueryState;
/*
@ -73,6 +78,17 @@ typedef struct {
uint32_t usage_map;
} CaptureListPool;
/*
 * PredicateStepType - The kinds of step that can appear in a predicate.
 * NOTE(review): nothing in the visible part of this file references this
 * enum; the parser emits `TSQueryPredicateStepType*` values instead. Confirm
 * whether it is still needed.
 */
typedef enum {
PredicateStepTypeSymbol,
PredicateStepTypeCapture,
PredicateStepTypeDone,
} PredicateStepType;
/*
 * PredicateStep - One element of a predicate: a flag saying whether the
 * element is a capture, plus an id.
 * NOTE(review): `TSQuery` stores `TSQueryPredicateStep` values, and nothing
 * in the visible part of this file uses this struct; confirm whether it is
 * still needed.
 */
typedef struct {
bool is_capture;
uint16_t value_id;
} PredicateStep;
/*
* TSQuery - A tree query, compiled from a string of S-expressions. The query
* itself is immutable. The mutable state used in the process of executing the
@ -80,9 +96,11 @@ typedef struct {
*/
struct TSQuery {
Array(QueryStep) steps;
Array(char) capture_data;
Array(CaptureSlice) capture_names;
SymbolTable captures;
SymbolTable predicate_values;
Array(PatternSlice) pattern_map;
Array(TSQueryPredicateStep) predicate_steps;
Array(Slice) predicates_by_pattern;
const TSLanguage *language;
uint16_t max_capture_count;
uint16_t wildcard_root_pattern_count;
@ -100,6 +118,7 @@ struct TSQueryCursor {
uint32_t depth;
uint32_t start_byte;
uint32_t end_byte;
uint32_t next_state_id;
TSPoint start_point;
TSPoint end_point;
bool ascending;
@ -177,7 +196,9 @@ static void stream_scan_identifier(Stream *stream) {
iswalnum(stream->next) ||
stream->next == '_' ||
stream->next == '-' ||
stream->next == '.'
stream->next == '.' ||
stream->next == '?' ||
stream->next == '!'
);
}
@ -222,6 +243,65 @@ static void capture_list_pool_release(CaptureListPool *self, uint16_t id) {
self->usage_map |= bitmask_for_index(id);
}
/**************
* SymbolTable
**************/
static SymbolTable symbol_table_new() {
return (SymbolTable) {
.characters = array_new(),
.slices = array_new(),
};
}
// Release both arrays owned by the table. The table struct itself is not
// freed here.
static void symbol_table_delete(SymbolTable *self) {
  array_delete(&self->slices);
  array_delete(&self->characters);
}
// Look up `name` (of `length` bytes; it need not be NUL-terminated) with a
// linear scan over the stored slices. Returns the index of the first matching
// entry, or -1 when the table does not contain the name.
static int symbol_table_id_for_name(
  const SymbolTable *self,
  const char *name,
  uint32_t length
) {
  for (unsigned index = 0; index < self->slices.size; index++) {
    const Slice *candidate = &self->slices.contents[index];

    // Cheap rejection first: names of different lengths can never match.
    if (candidate->length != length) continue;

    const char *stored = &self->characters.contents[candidate->offset];
    if (strncmp(stored, name, length) == 0) return index;
  }
  return -1;
}
// Return a pointer to the stored bytes of the string with the given id and
// write its length to `*length`. The id is used as an index into `slices`
// without any bounds check, so it must be a valid id for this table.
static const char *symbol_table_name_for_id(
  const SymbolTable *self,
  uint16_t id,
  uint32_t *length
) {
  const Slice *entry = &self->slices.contents[id];
  const char *name = &self->characters.contents[entry->offset];
  *length = entry->length;
  return name;
}
// Return the id for `name`, adding it to the table if it is not already
// present. New names are copied to the end of `characters`, followed by a
// zero byte, and a slice describing them is appended to `slices`.
static uint16_t symbol_table_insert_name(
  SymbolTable *self,
  const char *name,
  uint32_t length
) {
  // Reuse the existing entry when the name was interned before.
  int existing_id = symbol_table_id_for_name(self, name, length);
  if (existing_id >= 0) return (uint16_t)existing_id;

  // Remember where the new bytes will start before the array grows.
  uint32_t start = self->characters.size;

  // Reserve room for the name plus one trailing zero byte.
  array_grow_by(&self->characters, length + 1);
  memcpy(&self->characters.contents[start], name, length);
  self->characters.contents[self->characters.size - 1] = 0;

  array_push(&self->slices, ((Slice) {
    .offset = start,
    .length = length,
  }));
  return self->slices.size - 1;
}
/*********
* Query
*********/
@ -241,24 +321,6 @@ static TSSymbol ts_query_intern_node_name(
return 0;
}
// Return the id for the capture called `name` (`length` bytes long), adding
// it to the query if it is new. An existing id found through
// `ts_query_capture_id_for_name` is returned unchanged; otherwise the bytes
// are appended to `capture_data` followed by a zero byte, a slice describing
// them is pushed onto `capture_names`, and the index of that slice is
// returned.
// NOTE(review): elsewhere in this change the caller is switched to
// `symbol_table_insert_name`, which performs the same steps on a SymbolTable.
static uint16_t ts_query_intern_capture_name(
TSQuery *self,
const char *name,
uint32_t length
) {
int id = ts_query_capture_id_for_name(self, name, length);
if (id >= 0) return (uint16_t)id;
CaptureSlice capture = {
.offset = self->capture_data.size,
.length = length,
};
array_grow_by(&self->capture_data, length + 1);
memcpy(&self->capture_data.contents[capture.offset], name, length);
self->capture_data.contents[self->capture_data.size - 1] = 0;
array_push(&self->capture_names, capture);
return self->capture_names.size - 1;
}
// The `pattern_map` contains a mapping from TSSymbol values to indices in the
// `steps` array. For a given syntax node, the `pattern_map` makes it possible
// to quickly find the starting steps of all of the patterns whose root matches
@ -322,6 +384,110 @@ static inline void ts_query__pattern_map_insert(
}));
}
// Parse a single predicate: a parenthesized list whose elements are string
// literals, bare symbols, or `@`-prefixed capture names.
//
// Every element is appended to `self->predicate_steps`, and the `length` of
// the last entry in `self->predicates_by_pattern` is incremented for each
// appended step, including the terminating `Done` step. Strings and bare
// symbols are interned in `self->predicate_values`; captures must already
// exist in `self->captures`.
//
// Returns:
//   - PARENT_DONE if the stream is positioned at a `)`, meaning there are no
//     more predicates in the enclosing form (nothing is consumed).
//   - TSQueryErrorSyntax if the predicate does not start with `(`, contains
//     an unterminated string, or contains a character that cannot begin any
//     predicate element.
//   - TSQueryErrorCapture if a capture name is not known to the query; the
//     stream is reset to the start of that name.
//   - 0 once the closing `)` of the predicate has been consumed.
static TSQueryError ts_query_parse_predicate(
  TSQuery *self,
  Stream *stream
) {
  if (stream->next == ')') return PARENT_DONE;
  if (stream->next != '(') return TSQueryErrorSyntax;
  stream_advance(stream);
  stream_skip_whitespace(stream);

  for (;;) {
    // The closing paren ends this predicate; mark the end with a `Done` step.
    if (stream->next == ')') {
      stream_advance(stream);
      array_back(&self->predicates_by_pattern)->length++;
      array_push(&self->predicate_steps, ((TSQueryPredicateStep) {
        .type = TSQueryPredicateStepTypeDone,
        .value_id = 0,
      }));
      break;
    }

    // Parse an `@`-prefixed capture
    else if (stream->next == '@') {
      stream_advance(stream);
      stream_skip_whitespace(stream);

      // Parse the capture name
      if (!stream_is_ident_start(stream)) return TSQueryErrorSyntax;
      const char *capture_name = stream->input;
      stream_scan_identifier(stream);
      uint32_t length = stream->input - capture_name;

      // A predicate may only refer to captures that the query already
      // defines, so look the name up instead of interning it.
      int capture_id = symbol_table_id_for_name(
        &self->captures,
        capture_name,
        length
      );
      if (capture_id == -1) {
        stream_reset(stream, capture_name);
        return TSQueryErrorCapture;
      }

      array_back(&self->predicates_by_pattern)->length++;
      array_push(&self->predicate_steps, ((TSQueryPredicateStep) {
        .type = TSQueryPredicateStepTypeCapture,
        .value_id = capture_id,
      }));
    }

    // Parse a string literal
    else if (stream->next == '"') {
      stream_advance(stream);

      // Scan up to the closing quote. If the input cannot be advanced any
      // further first, point the stream back at the opening quote and fail.
      const char *string_content = stream->input;
      while (stream->next != '"') {
        if (!stream_advance(stream)) {
          stream_reset(stream, string_content - 1);
          return TSQueryErrorSyntax;
        }
      }
      uint32_t length = stream->input - string_content;

      // Intern the string's content and add a step that refers to it
      uint16_t id = symbol_table_insert_name(
        &self->predicate_values,
        string_content,
        length
      );
      array_back(&self->predicates_by_pattern)->length++;
      array_push(&self->predicate_steps, ((TSQueryPredicateStep) {
        .type = TSQueryPredicateStepTypeString,
        .value_id = id,
      }));

      // The loop above only falls through when the stream is positioned on
      // the closing quote, so consume it.
      stream_advance(stream);
    }

    // Parse a bare symbol
    else if (stream_is_ident_start(stream)) {
      const char *symbol_start = stream->input;
      stream_scan_identifier(stream);
      uint32_t length = stream->input - symbol_start;

      // Bare symbols share the string table with quoted strings
      uint16_t id = symbol_table_insert_name(
        &self->predicate_values,
        symbol_start,
        length
      );
      array_back(&self->predicates_by_pattern)->length++;
      array_push(&self->predicate_steps, ((TSQueryPredicateStep) {
        .type = TSQueryPredicateStepTypeString,
        .value_id = id,
      }));
    }

    // Anything else cannot begin a predicate element. None of the branches
    // above would consume it, so without this branch the loop would never
    // make progress (for example on a predicate that is missing its closing
    // paren, or on a stray punctuation character).
    else {
      return TSQueryErrorSyntax;
    }

    stream_skip_whitespace(stream);
  }

  return 0;
}
// Read one S-expression pattern from the stream, and incorporate it into
// the query's internal state machine representation. For nested patterns,
// this function calls itself recursively.
@ -344,6 +510,26 @@ static TSQueryError ts_query_parse_pattern(
else if (stream->next == '(') {
stream_advance(stream);
stream_skip_whitespace(stream);
// Parse a pattern inside of a conditional form
if (stream->next == '(' && depth == 0) {
TSQueryError e = ts_query_parse_pattern(self, stream, 0, capture_count);
if (e) return e;
// Parse the predicates that follow the pattern
stream_skip_whitespace(stream);
for (;;) {
TSQueryError e = ts_query_parse_predicate(self, stream);
if (e == PARENT_DONE) {
stream_advance(stream);
stream_skip_whitespace(stream);
return 0;
} else if (e) {
return e;
}
}
}
TSSymbol symbol;
// Parse the wildcard symbol
@ -494,8 +680,8 @@ static TSQueryError ts_query_parse_pattern(
uint32_t length = stream->input - capture_name;
// Add the capture id to the first step of the pattern
uint16_t capture_id = ts_query_intern_capture_name(
self,
uint16_t capture_id = symbol_table_insert_name(
&self->captures,
capture_name,
length
);
@ -519,6 +705,10 @@ TSQuery *ts_query_new(
*self = (TSQuery) {
.steps = array_new(),
.pattern_map = array_new(),
.captures = symbol_table_new(),
.predicate_values = symbol_table_new(),
.predicate_steps = array_new(),
.predicates_by_pattern = array_new(),
.wildcard_root_pattern_count = 0,
.max_capture_count = 0,
.language = language,
@ -531,6 +721,10 @@ TSQuery *ts_query_new(
for (;;) {
start_step_index = self->steps.size;
uint32_t capture_count = 0;
array_push(&self->predicates_by_pattern, ((Slice) {
.offset = self->predicate_steps.size,
.length = 0,
}));
*error_type = ts_query_parse_pattern(self, &stream, 0, &capture_count);
array_push(&self->steps, ((QueryStep) { .depth = PATTERN_DONE_MARKER }));
@ -569,14 +763,24 @@ void ts_query_delete(TSQuery *self) {
if (self) {
array_delete(&self->steps);
array_delete(&self->pattern_map);
array_delete(&self->capture_data);
array_delete(&self->capture_names);
array_delete(&self->predicate_steps);
array_delete(&self->predicates_by_pattern);
symbol_table_delete(&self->captures);
symbol_table_delete(&self->predicate_values);
ts_free(self);
}
}
// Get the number of entries in `predicates_by_pattern`. `ts_query_new`
// pushes one entry onto that array before parsing each pattern, so this
// doubles as the query's pattern count.
uint32_t ts_query_pattern_count(const TSQuery *self) {
  const uint32_t pattern_count = self->predicates_by_pattern.size;
  return pattern_count;
}
uint32_t ts_query_capture_count(const TSQuery *self) {
return self->capture_names.size;
return self->captures.slices.size;
}
uint32_t ts_query_string_count(const TSQuery *self) {
return self->predicate_values.slices.size;
}
const char *ts_query_capture_name_for_id(
@ -584,9 +788,15 @@ const char *ts_query_capture_name_for_id(
uint32_t index,
uint32_t *length
) {
CaptureSlice name = self->capture_names.contents[index];
*length = name.length;
return &self->capture_data.contents[name.offset];
return symbol_table_name_for_id(&self->captures, index, length);
}
const char *ts_query_string_value_for_id(
const TSQuery *self,
uint32_t index,
uint32_t *length
) {
return symbol_table_name_for_id(&self->predicate_values, index, length);
}
int ts_query_capture_id_for_name(
@ -594,14 +804,25 @@ int ts_query_capture_id_for_name(
const char *name,
uint32_t length
) {
for (unsigned i = 0; i < self->capture_names.size; i++) {
CaptureSlice existing = self->capture_names.contents[i];
if (
existing.length == length &&
!strncmp(&self->capture_data.contents[existing.offset], name, length)
) return i;
}
return -1;
return symbol_table_id_for_name(&self->captures, name, length);
}
int ts_query_string_id_for_value(
const TSQuery *self,
const char *value,
uint32_t length
) {
return symbol_table_id_for_name(&self->predicate_values, value, length);
}
// Get the predicate steps recorded for the pattern at `pattern_index`.
// Writes the number of steps to `*step_count` and returns a pointer to the
// first of them inside the query's `predicate_steps` array. The count
// includes the `Done` step that terminates each predicate. `pattern_index`
// is not bounds-checked.
const TSQueryPredicateStep *ts_query_predicates_for_pattern(
  const TSQuery *self,
  uint32_t pattern_index,
  uint32_t *step_count
) {
  const Slice *range = &self->predicates_by_pattern.contents[pattern_index];
  const TSQueryPredicateStep *first_step =
    &self->predicate_steps.contents[range->offset];
  *step_count = range->length;
  return first_step;
}
/***************
@ -640,6 +861,7 @@ void ts_query_cursor_exec(
array_clear(&self->finished_states);
ts_tree_cursor_reset(&self->cursor, node);
capture_list_pool_reset(&self->capture_list_pool, query->max_capture_count);
self->next_state_id = 0;
self->depth = 0;
self->ascending = false;
self->query = query;
@ -891,6 +1113,7 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) {
if (next_step->depth == PATTERN_DONE_MARKER) {
LOG("finish pattern %u\n", next_state->pattern_index);
next_state->id = self->next_state_id++;
array_push(&self->finished_states, *next_state);
if (next_state == state) {
array_erase(&self->states, i);
@ -915,9 +1138,7 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) {
bool ts_query_cursor_next_match(
TSQueryCursor *self,
uint32_t *pattern_index,
uint32_t *capture_count,
const TSQueryCapture **captures
TSQueryMatch *match
) {
if (self->finished_states.size > 0) {
QueryState state = array_pop(&self->finished_states);
@ -927,9 +1148,10 @@ bool ts_query_cursor_next_match(
if (!ts_query_cursor__advance(self)) return false;
const QueryState *state = array_back(&self->finished_states);
*pattern_index = state->pattern_index;
*capture_count = state->capture_count;
*captures = capture_list_pool_get(
match->id = state->id;
match->pattern_index = state->pattern_index;
match->capture_count = state->capture_count;
match->captures = capture_list_pool_get(
&self->capture_list_pool,
state->capture_list_id
);
@ -939,7 +1161,8 @@ bool ts_query_cursor_next_match(
bool ts_query_cursor_next_capture(
TSQueryCursor *self,
TSQueryCapture *capture
TSQueryMatch *match,
uint32_t *capture_index
) {
for (;;) {
if (self->finished_states.size > 0) {
@ -991,19 +1214,15 @@ bool ts_query_cursor_next_capture(
QueryState *state = &self->finished_states.contents[
first_finished_state_index
];
const TSQueryCapture *captures = capture_list_pool_get(
match->id = state->id;
match->pattern_index = state->pattern_index;
match->capture_count = state->capture_count;
match->captures = capture_list_pool_get(
&self->capture_list_pool,
state->capture_list_id
);
*capture = captures[state->consumed_capture_count];
*capture_index = state->consumed_capture_count;
state->consumed_capture_count++;
if (state->consumed_capture_count == state->capture_count) {
capture_list_pool_release(
&self->capture_list_pool,
state->capture_list_id
);
array_erase(&self->finished_states, first_finished_state_index);
}
return true;
}
}