Add negated field patterns to queries

This commit is contained in:
Max Brunsfeld 2021-03-12 14:38:02 -08:00
parent 62a61c3540
commit bc0ef5f373
2 changed files with 215 additions and 0 deletions

View file

@ -132,6 +132,18 @@ fn test_query_errors_on_invalid_syntax() {
.join("\n")
);
// Need a field name after a negated field operator
assert_eq!(
Query::new(language, r#"(statement_block ! (if_statement))"#)
.unwrap_err()
.message,
[
r#"(statement_block ! (if_statement))"#,
r#" ^"#
]
.join("\n")
);
// tree-sitter/tree-sitter/issues/968
assert_eq!(
Query::new(get_language("c"), r#"(parameter_list [ ")" @foo)"#)
@ -201,6 +213,26 @@ fn test_query_errors_on_invalid_symbols() {
message: "conditioning".to_string()
}
);
assert_eq!(
Query::new(language, "(if_statement !alternativ)").unwrap_err(),
QueryError {
row: 0,
offset: 15,
column: 15,
kind: QueryErrorKind::Field,
message: "alternativ".to_string()
}
);
assert_eq!(
Query::new(language, "(if_statement !alternatives)").unwrap_err(),
QueryError {
row: 0,
offset: 15,
column: 15,
kind: QueryErrorKind::Field,
message: "alternatives".to_string()
}
);
});
}
@ -894,6 +926,71 @@ fn test_query_matches_with_last_named_child() {
});
}
#[test]
fn test_query_matches_with_negated_fields() {
allocations::record(|| {
let language = get_language("javascript");
let query = Query::new(
language,
"
(import_specifier
!alias
name: (identifier) @import_name)
(export_specifier
!alias
name: (identifier) @export_name)
(export_statement
!decorator
!source
(_) @exported)
; This negated field list is an extension of a previous
; negated field list. The order of the children and negated
; fields doesn't matter.
(export_statement
!decorator
!source
(_) @exported_expr
!declaration)
; This negated field list is a prefix of a previous
; negated field list.
(export_statement
!decorator
(_) @export_child .)
",
)
.unwrap();
assert_query_matches(
language,
&query,
"
import {a as b, c} from 'p1';
export {g, h as i} from 'p2';
@foo
export default 1;
export var j = 1;
export default k;
",
&[
(0, vec![("import_name", "c")]),
(1, vec![("export_name", "g")]),
(4, vec![("export_child", "'p2'")]),
(2, vec![("exported", "var j = 1;")]),
(4, vec![("export_child", "var j = 1;")]),
(2, vec![("exported", "k")]),
(3, vec![("exported_expr", "k")]),
(4, vec![("export_child", "k")]),
],
);
});
}
#[test]
fn test_query_matches_with_repeated_leaf_nodes() {
allocations::record(|| {

View file

@ -16,6 +16,7 @@
#define MAX_STEP_CAPTURE_COUNT 3
#define MAX_STATE_PREDECESSOR_COUNT 100
#define MAX_ANALYSIS_STATE_DEPTH 12
#define MAX_NEGATED_FIELD_COUNT 8
/*
* Stream - A sequence of unicode characters derived from a UTF8 string.
@ -75,6 +76,7 @@ typedef struct {
uint16_t capture_ids[MAX_STEP_CAPTURE_COUNT];
uint16_t depth;
uint16_t alternative_index;
uint16_t negated_field_list_id;
bool contains_captures: 1;
bool is_immediate: 1;
bool is_last_child: 1;
@ -239,6 +241,7 @@ struct TSQuery {
Array(TSQueryPredicateStep) predicate_steps;
Array(QueryPattern) patterns;
Array(StepOffset) step_offsets;
Array(TSFieldId) negated_fields;
Array(char) string_buffer;
const TSLanguage *language;
uint16_t wildcard_root_pattern_count;
@ -480,6 +483,7 @@ static QueryStep query_step__new(
.field = 0,
.capture_ids = {NONE, NONE, NONE},
.alternative_index = NONE,
.negated_field_list_id = 0,
.contains_captures = false,
.is_last_child = false,
.is_pass_through = false,
@ -1366,6 +1370,58 @@ static void ts_query__finalize_steps(TSQuery *self) {
}
}
static void ts_query__add_negated_fields(
TSQuery *self,
uint16_t step_index,
TSFieldId *field_ids,
uint16_t field_count
) {
QueryStep *step = &self->steps.contents[step_index];
// The negated field array stores a list of field lists, separated by zeros.
// Try to find the start index of an existing list that matches this new list.
bool failed_match = false;
unsigned match_count = 0;
unsigned start_i = 0;
for (unsigned i = 0; i < self->negated_fields.size; i++) {
TSFieldId existing_field_id = self->negated_fields.contents[i];
// At each zero value, terminate the match attempt. If we've exactly
// matched the new field list, then reuse this index. Otherwise,
// start over the matching process.
if (existing_field_id == 0) {
if (match_count == field_count) {
step->negated_field_list_id = start_i;
return;
} else {
start_i = i + 1;
match_count = 0;
failed_match = false;
}
}
// If the existing list matches our new list so far, then advance
// to the next element of the new list.
else if (
match_count < field_count &&
existing_field_id == field_ids[match_count] &&
!failed_match
) {
match_count++;
}
// Otherwise, this existing list has failed to match.
else {
match_count = 0;
failed_match = true;
}
}
step->negated_field_list_id = self->negated_fields.size;
array_extend(&self->negated_fields, field_count, field_ids);
array_push(&self->negated_fields, 0);
}
static TSQueryError ts_query__parse_string_literal(
TSQuery *self,
Stream *stream
@ -1716,7 +1772,39 @@ static TSQueryError ts_query__parse_pattern(
// Parse the child patterns
bool child_is_immediate = false;
uint16_t last_child_step_index = 0;
uint16_t negated_field_count = 0;
TSFieldId negated_field_ids[MAX_NEGATED_FIELD_COUNT];
for (;;) {
// Parse a negated field assertion
if (stream->next == '!') {
stream_advance(stream);
stream_skip_whitespace(stream);
if (!stream_is_ident_start(stream)) return TSQueryErrorSyntax;
const char *field_name = stream->input;
stream_scan_identifier(stream);
uint32_t length = stream->input - field_name;
stream_skip_whitespace(stream);
TSFieldId field_id = ts_language_field_id_for_name(
self->language,
field_name,
length
);
if (!field_id) {
stream->input = field_name;
return TSQueryErrorField;
}
// Keep the field ids sorted.
if (negated_field_count < MAX_NEGATED_FIELD_COUNT) {
negated_field_ids[negated_field_count] = field_id;
negated_field_count++;
}
continue;
}
// Parse a sibling anchor
if (stream->next == '.') {
child_is_immediate = true;
stream_advance(stream);
@ -1737,6 +1825,16 @@ static TSQueryError ts_query__parse_pattern(
}
self->steps.contents[last_child_step_index].is_last_child = true;
}
if (negated_field_count) {
ts_query__add_negated_fields(
self,
starting_step_index,
negated_field_ids,
negated_field_count
);
}
stream_advance(stream);
break;
} else if (e) {
@ -1945,10 +2043,13 @@ TSQuery *ts_query_new(
.patterns = array_new(),
.step_offsets = array_new(),
.string_buffer = array_new(),
.negated_fields = array_new(),
.wildcard_root_pattern_count = 0,
.language = language,
};
array_push(&self->negated_fields, 0);
// Parse all of the S-expressions in the given string.
Stream stream = stream_new(source, source_len);
stream_skip_whitespace(&stream);
@ -2033,6 +2134,7 @@ void ts_query_delete(TSQuery *self) {
array_delete(&self->patterns);
array_delete(&self->step_offsets);
array_delete(&self->string_buffer);
array_delete(&self->negated_fields);
symbol_table_delete(&self->captures);
symbol_table_delete(&self->predicate_values);
ts_free(self);
@ -2700,6 +2802,22 @@ static inline bool ts_query_cursor__advance(
}
}
if (step->negated_field_list_id) {
TSFieldId *negated_field_ids = &self->query->negated_fields.contents[step->negated_field_list_id];
for (;;) {
TSFieldId negated_field_id = *negated_field_ids;
if (negated_field_id) {
negated_field_ids++;
if (ts_node_child_by_field_id(node, negated_field_id).id) {
node_does_match = false;
break;
}
} else {
break;
}
}
}
// Remove states immediately if it is ever clear that they cannot match.
if (!node_does_match) {
if (!later_sibling_can_match) {