query: Differentiate between wildcard '*' and named wildcard '(*)'

This commit is contained in:
Max Brunsfeld 2020-02-19 09:42:29 -08:00
parent 1d6ea51b63
commit 950a89a525
2 changed files with 36 additions and 2 deletions

View file

@ -402,6 +402,38 @@ fn test_query_matches_capturing_error_nodes() {
});
}
#[test]
fn test_query_matches_with_named_wildcard() {
allocations::record(|| {
let language = get_language("javascript");
let query = Query::new(
language,
"
(return_statement (*) @the-return-value)
(binary_expression operator: * @the-operator)
",
)
.unwrap();
let source = "return a + b - c;";
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query, source),
&[
(0, vec![("the-return-value", "a + b - c")]),
(1, vec![("the-operator", "+")]),
(1, vec![("the-operator", "-")]),
]
);
});
}
#[test]
fn test_query_matches_in_language_with_simple_aliases() {
allocations::record(|| {

View file

@ -144,6 +144,7 @@ static const TSQueryError PARENT_DONE = -1;
static const uint8_t PATTERN_DONE_MARKER = UINT8_MAX;
static const uint16_t NONE = UINT16_MAX;
static const TSSymbol WILDCARD_SYMBOL = 0;
static const TSSymbol NAMED_WILDCARD_SYMBOL = UINT16_MAX - 1;
static const uint16_t MAX_STATE_COUNT = 32;
// #define LOG(...) fprintf(stderr, __VA_ARGS__)
@ -615,7 +616,7 @@ static TSQueryError ts_query__parse_pattern(
// Parse the wildcard symbol
if (stream->next == '*') {
symbol = WILDCARD_SYMBOL;
symbol = NAMED_WILDCARD_SYMBOL;
stream_advance(stream);
}
@ -1240,7 +1241,8 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) {
// pattern.
bool node_does_match =
step->symbol == symbol ||
(!step->symbol && ts_node_is_named(node));
step->symbol == WILDCARD_SYMBOL ||
(step->symbol == NAMED_WILDCARD_SYMBOL && ts_node_is_named(node));
bool later_sibling_can_match = can_have_later_siblings;
if (step->field) {
if (step->field == field_id) {