Add wasm binding for running tree queries in a limited range

This commit is contained in:
Max Brunsfeld 2019-09-11 14:44:49 -07:00
parent 0528ad5f58
commit 49ce2fddb9
6 changed files with 143 additions and 24 deletions

View file

@ -56,8 +56,12 @@ fn test_query_errors_on_invalid_symbols() {
Err(QueryError::NodeType("non_existent3"))
);
assert_eq!(
Query::new(language, "(if_statement not_a_field: (identifier))"),
Err(QueryError::Field("not_a_field"))
Query::new(language, "(if_statement condit: (identifier))"),
Err(QueryError::Field("condit"))
);
assert_eq!(
Query::new(language, "(if_statement conditioning: (identifier))"),
Err(QueryError::Field("conditioning"))
);
});
}
@ -368,6 +372,67 @@ fn test_query_exec_within_byte_range() {
});
}
/// Runs three different queries, one after another, through a single
/// `QueryCursor` and checks the matches reported for each run.
#[test]
fn test_query_exec_different_queries() {
    allocations::record(|| {
        let language = get_language("javascript");

        // Each query has one more pattern than the previous one.
        let one_pattern_query = Query::new(
            language,
            "
(array (identifier) @id1)
",
        )
        .unwrap();
        let two_pattern_query = Query::new(
            language,
            "
(array (identifier) @id1)
(pair (identifier) @id2)
",
        )
        .unwrap();
        let three_pattern_query = Query::new(
            language,
            "
(array (identifier) @id1)
(pair (identifier) @id2)
(parenthesized_expression (identifier) @id3)
",
        )
        .unwrap();

        let source = "[a, {b: b}, (c)];";
        let mut parser = Parser::new();
        let mut cursor = QueryCursor::new();
        parser.set_language(language).unwrap();
        let tree = parser.parse(&source, None).unwrap();

        // The same cursor is used for every run: first the one-pattern query,
        // then the three-pattern query, and finally the two-pattern query.
        assert_eq!(
            collect_matches(
                cursor.exec(&one_pattern_query, tree.root_node()),
                &one_pattern_query,
                source
            ),
            &[(0, vec![("id1", "a")]),]
        );

        assert_eq!(
            collect_matches(
                cursor.exec(&three_pattern_query, tree.root_node()),
                &three_pattern_query,
                source
            ),
            &[
                (0, vec![("id1", "a")]),
                (1, vec![("id2", "b")]),
                (2, vec![("id3", "c")]),
            ]
        );

        assert_eq!(
            collect_matches(
                cursor.exec(&two_pattern_query, tree.root_node()),
                &two_pattern_query,
                source
            ),
            &[(0, vec![("id1", "a")]), (1, vec![("id2", "b")]),]
        );
    });
}
#[test]
fn test_query_capture_names() {
allocations::record(|| {

View file

@ -567,15 +567,25 @@ int ts_node_is_missing_wasm(const TSTree *tree) {
/* Section - Query */
/******************/
void ts_query_exec_wasm(const TSQuery *self, const TSTree *tree) {
void ts_query_exec_wasm(
const TSQuery *self,
const TSTree *tree,
uint32_t start_row,
uint32_t start_column,
uint32_t end_row,
uint32_t end_column
) {
if (!scratch_query_cursor) scratch_query_cursor = ts_query_cursor_new();
TSNode node = unmarshal_node(tree);
TSPoint start_point = {start_row, code_unit_to_byte(start_column)};
TSPoint end_point = {end_row, code_unit_to_byte(end_column)};
Array(const void *) result = array_new();
unsigned index = 0;
unsigned match_count = 0;
ts_query_cursor_set_point_range(scratch_query_cursor, start_point, end_point);
ts_query_cursor_exec(scratch_query_cursor, self, node);
while (ts_query_cursor_next(scratch_query_cursor)) {
match_count++;
@ -586,7 +596,7 @@ void ts_query_exec_wasm(const TSQuery *self, const TSTree *tree) {
&capture_count
);
array_grow_by(&result, 1 + 6 * capture_count);
array_grow_by(&result, 2 + 6 * capture_count);
result.contents[index++] = (const void *)pattern_index;
result.contents[index++] = (const void *)capture_count;

View file

@ -688,24 +688,30 @@ class Language {
const nameLength = getValue(TRANSFER_BUFFER, 'i32');
captureNames[i] = UTF8ToString(nameAddress, nameLength);
}
C._free(sourceAddress);
return new Query(INTERNAL, address, captureNames);
} else {
const errorId = getValue(TRANSFER_BUFFER + SIZE_OF_INT, 'i32');
const utf8ErrorOffset = getValue(TRANSFER_BUFFER, 'i32');
const errorOffset = UTF8ToString(sourceAddress, utf8ErrorOffset).length;
C._free(sourceAddress);
const suffix = source.slice(errorOffset, 100);
const errorByte = getValue(TRANSFER_BUFFER, 'i32');
const errorIndex = UTF8ToString(sourceAddress, errorByte).length;
const suffix = source.slice(errorIndex, 100);
const word = suffix.match(QUERY_WORD_REGEX)[0];
let error;
switch (errorId) {
case 2: throw new RangeError(
`Bad node name '${suffix.match(QUERY_WORD_REGEX)[0]}'`
);
case 3: throw new RangeError(
`Bad field name '${suffix.match(QUERY_WORD_REGEX)[0]}'`
);
default: throw new SyntaxError(
`Bad syntax at offset ${errorOffset}: '${suffix}'...`
);
case 2:
error = new RangeError(`Bad node name '${word}'`);
break;
case 3:
error = new RangeError(`Bad field name '${word}'`);
break;
default:
error = new SyntaxError(`Bad syntax at offset ${errorIndex}: '${suffix}'...`);
break;
}
error.index = errorIndex;
error.length = word.length;
C._free(sourceAddress);
throw error;
}
}
@ -752,10 +758,20 @@ class Query {
C._ts_query_delete(this[0]);
}
exec(queryNode) {
exec(queryNode, startPosition, endPosition) {
if (!startPosition) startPosition = ZERO_POINT;
if (!endPosition) endPosition = ZERO_POINT;
marshalNode(queryNode);
C._ts_query_exec_wasm(this[0], queryNode.tree[0]);
C._ts_query_exec_wasm(
this[0],
queryNode.tree[0],
startPosition.row,
startPosition.column,
endPosition.row,
endPosition.column
);
const matchCount = getValue(TRANSFER_BUFFER, 'i32');
const nodesAddress = getValue(TRANSFER_BUFFER + SIZE_OF_INT, 'i32');

View file

@ -41,10 +41,7 @@ describe("Query", () => {
`);
const matches = query.exec(tree.rootNode);
assert.deepEqual(
matches.map(({pattern, captures}) => ({
pattern,
captures: captures.map(({name, node}) => ({name, text: node.text}))
})),
formatMatches(matches),
[
{pattern: 0, captures: [{name: 'fn-def', text: 'one'}]},
{pattern: 1, captures: [{name: 'fn-ref', text: 'two'}]},
@ -52,4 +49,33 @@ describe("Query", () => {
]
);
});
it('matches queries in specified ranges', () => {
  // Four rows of source text, two identifiers on each row.
  tree = parser.parse("[a, b,\nc, d,\ne, f,\ng, h]");
  query = JavaScript.query('(identifier) @element');

  // Query only between row 1, column 1 and row 3, column 1.
  const rangeStart = {row: 1, column: 1};
  const rangeEnd = {row: 3, column: 1};
  const matchesInRange = query.exec(tree.rootNode, rangeStart, rangeEnd);

  // Only `d` through `g` are expected to be reported for that range.
  const expected = ['d', 'e', 'f', 'g'].map(text => ({
    pattern: 0,
    captures: [{name: 'element', text}]
  }));

  assert.deepEqual(formatMatches(matchesInRange), expected);
});
});
// Convert query matches into plain objects that are easy to compare in
// assertions: every match keeps its `pattern`, and every capture is reduced
// to its `name` and the `text` of its node.
function formatMatches(matches) {
  const formatCapture = capture => ({
    name: capture.name,
    text: capture.node.text
  });
  const formatMatch = match => ({
    pattern: match.pattern,
    captures: match.captures.map(formatCapture)
  });
  return matches.map(formatMatch);
}

View file

@ -96,7 +96,8 @@ TSFieldId ts_language_field_id_for_name(
for (TSSymbol i = 1; i < count + 1; i++) {
switch (strncmp(name, self->field_names[i], name_length)) {
case 0:
return i;
if (self->field_names[i][name_length] == 0) return i;
break;
case -1:
return 0;
default:

View file

@ -783,6 +783,7 @@ bool ts_query_cursor_next(TSQueryCursor *self) {
.step_index = slice->step_index,
.pattern_index = slice->pattern_index,
.capture_list_id = capture_list_id,
.capture_count = 0,
}));
}