Add wasm bindings for predicates

This commit is contained in:
Max Brunsfeld 2019-09-16 10:25:44 -07:00
parent 096126d039
commit d4d554b2ae
6 changed files with 249 additions and 41 deletions

View file

@ -36,6 +36,15 @@ fn test_query_errors_on_invalid_syntax() {
Query::new(language, r#"(identifier) "h "#), Query::new(language, r#"(identifier) "h "#),
Err(QueryError::Syntax(13)) Err(QueryError::Syntax(13))
); );
assert_eq!(
Query::new(language, r#"((identifier) ()"#),
Err(QueryError::Syntax(16))
);
assert_eq!(
Query::new(language, r#"((identifier) @x (eq? @x a"#),
Err(QueryError::Syntax(26))
);
}); });
} }

View file

@ -155,9 +155,7 @@ pub struct QueryCursor(*mut ffi::TSQueryCursor);
pub struct QueryMatch<'a> { pub struct QueryMatch<'a> {
pattern_index: usize, pattern_index: usize,
capture_count: usize, captures: &'a [ffi::TSQueryCapture],
captures_ptr: *const ffi::TSQueryCapture,
cursor: PhantomData<&'a ()>,
} }
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
@ -1140,24 +1138,32 @@ impl QueryCursor {
&'a mut self, &'a mut self,
query: &'a Query, query: &'a Query,
node: Node<'a>, node: Node<'a>,
text_callback: impl FnMut(Node<'a>) -> &'a [u8], mut text_callback: impl FnMut(Node<'a>) -> &'a [u8] + 'a,
) -> impl Iterator<Item = QueryMatch<'a>> + 'a { ) -> impl Iterator<Item = QueryMatch<'a>> + 'a {
unsafe { unsafe {
ffi::ts_query_cursor_exec(self.0, query.ptr, node.0); ffi::ts_query_cursor_exec(self.0, query.ptr, node.0);
} }
std::iter::from_fn(move || -> Option<QueryMatch<'a>> { std::iter::from_fn(move || -> Option<QueryMatch<'a>> {
unsafe { loop {
let mut m = MaybeUninit::<ffi::TSQueryMatch>::uninit(); unsafe {
if ffi::ts_query_cursor_next_match(self.0, m.as_mut_ptr()) { let mut m = MaybeUninit::<ffi::TSQueryMatch>::uninit();
let m = m.assume_init(); if ffi::ts_query_cursor_next_match(self.0, m.as_mut_ptr()) {
Some(QueryMatch { let m = m.assume_init();
pattern_index: m.pattern_index as usize, let captures = slice::from_raw_parts(m.captures, m.capture_count as usize);
capture_count: m.capture_count as usize, if self.captures_match_condition(
captures_ptr: m.captures, query,
cursor: PhantomData, captures,
}) m.pattern_index as usize,
} else { &mut text_callback,
None ) {
return Some(QueryMatch {
pattern_index: m.pattern_index as usize,
captures,
});
}
} else {
return None;
}
} }
} }
}) })
@ -1260,9 +1266,7 @@ impl<'a> QueryMatch<'a> {
} }
pub fn captures(&self) -> impl ExactSizeIterator<Item = (usize, Node)> { pub fn captures(&self) -> impl ExactSizeIterator<Item = (usize, Node)> {
let captures = self.captures
unsafe { slice::from_raw_parts(self.captures_ptr, self.capture_count as usize) };
captures
.iter() .iter()
.map(|capture| (capture.index as usize, Node::new(capture.node).unwrap())) .map(|capture| (capture.index as usize, Node::new(capture.node).unwrap()))
} }

View file

@ -7,6 +7,10 @@ const SIZE_OF_RANGE = 2 * SIZE_OF_INT + 2 * SIZE_OF_POINT;
const ZERO_POINT = {row: 0, column: 0}; const ZERO_POINT = {row: 0, column: 0};
const QUERY_WORD_REGEX = /[\w-.]*/g; const QUERY_WORD_REGEX = /[\w-.]*/g;
const PREDICATE_STEP_TYPE_DONE = 0;
const PREDICATE_STEP_TYPE_CAPTURE = 1;
const PREDICATE_STEP_TYPE_STRING = 2;
var VERSION; var VERSION;
var MIN_COMPATIBLE_VERSION; var MIN_COMPATIBLE_VERSION;
var TRANSFER_BUFFER; var TRANSFER_BUFFER;
@ -661,21 +665,8 @@ class Language {
TRANSFER_BUFFER, TRANSFER_BUFFER,
TRANSFER_BUFFER + SIZE_OF_INT TRANSFER_BUFFER + SIZE_OF_INT
); );
if (address) {
const captureCount = C._ts_query_capture_count(address); if (!address) {
const captureNames = new Array(captureCount);
for (let i = 0; i < captureCount; i++) {
const nameAddress = C._ts_query_capture_name_for_id(
address,
i,
TRANSFER_BUFFER
);
const nameLength = getValue(TRANSFER_BUFFER, 'i32');
captureNames[i] = UTF8ToString(nameAddress, nameLength);
}
C._free(sourceAddress);
return new Query(INTERNAL, address, captureNames);
} else {
const errorId = getValue(TRANSFER_BUFFER + SIZE_OF_INT, 'i32'); const errorId = getValue(TRANSFER_BUFFER + SIZE_OF_INT, 'i32');
const errorByte = getValue(TRANSFER_BUFFER, 'i32'); const errorByte = getValue(TRANSFER_BUFFER, 'i32');
const errorIndex = UTF8ToString(sourceAddress, errorByte).length; const errorIndex = UTF8ToString(sourceAddress, errorByte).length;
@ -689,6 +680,9 @@ class Language {
case 3: case 3:
error = new RangeError(`Bad field name '${word}'`); error = new RangeError(`Bad field name '${word}'`);
break; break;
case 4:
error = new RangeError(`Bad capture name @${word}`);
break;
default: default:
error = new SyntaxError(`Bad syntax at offset ${errorIndex}: '${suffix}'...`); error = new SyntaxError(`Bad syntax at offset ${errorIndex}: '${suffix}'...`);
break; break;
@ -698,6 +692,63 @@ class Language {
C._free(sourceAddress); C._free(sourceAddress);
throw error; throw error;
} }
const stringCount = C._ts_query_string_count(address);
const captureCount = C._ts_query_capture_count(address);
const patternCount = C._ts_query_pattern_count(address);
const captureNames = new Array(captureCount);
const stringValues = new Array(stringCount);
for (let i = 0; i < captureCount; i++) {
const nameAddress = C._ts_query_capture_name_for_id(
address,
i,
TRANSFER_BUFFER
);
const nameLength = getValue(TRANSFER_BUFFER, 'i32');
captureNames[i] = UTF8ToString(nameAddress, nameLength);
}
for (let i = 0; i < stringCount; i++) {
const valueAddress = C._ts_query_string_value_for_id(
address,
i,
TRANSFER_BUFFER
);
const nameLength = getValue(TRANSFER_BUFFER, 'i32');
stringValues[i] = UTF8ToString(valueAddress, nameLength);
}
const predicates = new Array(patternCount);
for (let i = 0; i < patternCount; i++) {
const predicatesAddress = C._ts_query_predicates_for_pattern(
address,
i,
TRANSFER_BUFFER
);
const stepCount = getValue(TRANSFER_BUFFER, 'i32');
predicates[i] = [];
const steps = [];
let stepAddress = predicatesAddress;
for (let j = 0; j < stepCount; j++) {
const stepType = getValue(stepAddress, 'i32');
stepAddress += SIZE_OF_INT;
const stepValueId = getValue(stepAddress, 'i32');
stepAddress += SIZE_OF_INT;
if (stepType === PREDICATE_STEP_TYPE_CAPTURE) {
steps.push({type: 'capture', name: captureNames[stepValueId]});
} else if (stepType === PREDICATE_STEP_TYPE_STRING) {
steps.push({type: 'string', value: stringValues[stepValueId]});
} else if (steps.length > 0) {
predicates[i].push(buildQueryPredicate(steps));
steps.length = 0;
}
}
}
C._free(sourceAddress);
return new Query(INTERNAL, address, captureNames, predicates);
} }
static load(url) { static load(url) {
@ -733,10 +784,11 @@ class Language {
} }
class Query { class Query {
constructor(internal, address, captureNames) { constructor(internal, address, captureNames, predicates) {
assertInternal(internal); assertInternal(internal);
this[0] = address; this[0] = address;
this.captureNames = captureNames; this.captureNames = captureNames;
this.predicates = predicates;
} }
delete() { delete() {
@ -771,7 +823,9 @@ class Query {
const captures = new Array(captureCount); const captures = new Array(captureCount);
address = unmarshalCaptures(this, node.tree, address, captures); address = unmarshalCaptures(this, node.tree, address, captures);
result[i] = {pattern, captures}; if (this.predicates[pattern].every(p => p(captures))) {
result[i] = {pattern, captures};
}
} }
C._free(startAddress); C._free(startAddress);
@ -809,7 +863,7 @@ class Query {
const captures = new Array(captureCount); const captures = new Array(captureCount);
address = unmarshalCaptures(this, node.tree, address, captures); address = unmarshalCaptures(this, node.tree, address, captures);
if (capturesMatchConditions(this, node.tree, pattern, captures)) { if (this.predicates[pattern].every(p => p(captures))) {
result.push(captures[captureIndex]); result.push(captures[captureIndex]);
} }
} }
@ -819,8 +873,63 @@ class Query {
} }
} }
function capturesMatchConditions(query, tree, pattern, captures) { function buildQueryPredicate(steps) {
return true; if (steps[0].type !== 'string') {
throw new Error('Predicates must begin with a literal value');
}
switch (steps[0].value) {
case 'eq?':
if (steps.length !== 3) throw new Error(
`Wrong number of arguments to \`eq?\` predicate. Expected 2, got ${steps.length - 1}`
);
if (steps[1].type !== 'capture') throw new Error(
`First argument of \`eq?\` predicate must be a capture. Got "${steps[1].value}"`
);
if (steps[2].type === 'capture') {
const captureName1 = steps[1].name;
const captureName2 = steps[2].name;
return function(captures) {
let node1, node2
for (const c of captures) {
if (c.name === captureName1) node1 = c.node;
if (c.name === captureName2) node2 = c.node;
}
return node1.text === node2.text
}
} else {
const captureName = steps[1].name;
const stringValue = steps[2].value;
return function(captures) {
for (const c of captures) {
if (c.name === captureName) return c.node.text === stringValue;
}
return false;
}
}
case 'match?':
if (steps.length !== 3) throw new Error(
`Wrong number of arguments to \`match?\` predicate. Expected 2, got ${steps.length - 1}.`
);
if (steps[1].type !== 'capture') throw new Error(
`First argument of \`match?\` predicate must be a capture. Got "${steps[1].value}".`
);
if (steps[2].type !== 'string') throw new Error(
`Second argument of \`match?\` predicate must be a string. Got @${steps[2].value}.`
);
const captureName = steps[1].name;
const regex = new RegExp(steps[2].value);
return function(captures) {
for (const c of captures) {
if (c.name === captureName) return regex.test(c.node.text);
}
return false;
}
default:
throw new Error(`Unknown query predicate \`${steps[0].value}\``);
}
} }
function unmarshalCaptures(query, tree, address, result) { function unmarshalCaptures(query, tree, address, result) {

View file

@ -70,12 +70,16 @@
"_ts_parser_set_language", "_ts_parser_set_language",
"_ts_query_capture_count", "_ts_query_capture_count",
"_ts_query_capture_name_for_id", "_ts_query_capture_name_for_id",
"_ts_query_captures_wasm",
"_ts_query_context_delete", "_ts_query_context_delete",
"_ts_query_context_new", "_ts_query_context_new",
"_ts_query_delete", "_ts_query_delete",
"_ts_query_matches_wasm", "_ts_query_matches_wasm",
"_ts_query_captures_wasm",
"_ts_query_new", "_ts_query_new",
"_ts_query_pattern_count",
"_ts_query_predicates_for_pattern",
"_ts_query_string_count",
"_ts_query_string_value_for_id",
"_ts_tree_cursor_current_field_id_wasm", "_ts_tree_cursor_current_field_id_wasm",
"_ts_tree_cursor_current_node_id_wasm", "_ts_tree_cursor_current_node_id_wasm",
"_ts_tree_cursor_current_node_is_missing_wasm", "_ts_tree_cursor_current_node_is_missing_wasm",

View file

@ -19,7 +19,7 @@ describe("Query", () => {
}); });
describe('construction', () => { describe('construction', () => {
it('throws an error on invalid syntax', () => { it('throws an error on invalid patterns', () => {
assert.throws(() => { assert.throws(() => {
JavaScript.query("(function_declaration wat)") JavaScript.query("(function_declaration wat)")
}, "Bad syntax at offset 22: \'wat)\'..."); }, "Bad syntax at offset 22: \'wat)\'...");
@ -33,6 +33,24 @@ describe("Query", () => {
JavaScript.query("(function_declaration non_existent:(identifier))") JavaScript.query("(function_declaration non_existent:(identifier))")
}, "Bad field name 'non_existent'"); }, "Bad field name 'non_existent'");
}); });
it('throws an error on invalid predicates', () => {
assert.throws(() => {
JavaScript.query("((identifier) @abc (eq? @ab hi))")
}, "Bad capture name @ab");
assert.throws(() => {
JavaScript.query("((identifier) @abc (eq? @ab hi))")
}, "Bad capture name @ab");
assert.throws(() => {
JavaScript.query("((identifier) @abc (eq?))")
}, "Wrong number of arguments to `eq?` predicate. Expected 2, got 0");
assert.throws(() => {
JavaScript.query("((identifier) @a (eq? @a @a @a))")
}, "Wrong number of arguments to `eq?` predicate. Expected 2, got 3");
assert.throws(() => {
JavaScript.query("((identifier) @a (something-else? @a))")
}, "Unknown query predicate `something-else?`");
});
}); });
describe('.matches', () => { describe('.matches', () => {
@ -119,6 +137,66 @@ describe("Query", () => {
] ]
); );
}); });
it('handles conditions that compare the text of capture to literal strings', () => {
tree = parser.parse(`
const ab = require('./ab');
new Cd(EF);
`);
query = JavaScript.query(`
(identifier) @variable
((identifier) @function.builtin
(eq? @function.builtin "require"))
((identifier) @constructor
(match? @constructor "^[A-Z]"))
((identifier) @constant
(match? @constant "^[A-Z]{2,}$"))
`);
const captures = query.captures(tree.rootNode);
assert.deepEqual(
formatCaptures(captures),
[
{name: "variable", text: "ab"},
{name: "variable", text: "require"},
{name: "function.builtin", text: "require"},
{name: "variable", text: "Cd"},
{name: "constructor", text: "Cd"},
{name: "variable", text: "EF"},
{name: "constructor", text: "EF"},
{name: "constant", text: "EF"},
]
);
});
it('handles conditions that compare the text of capture to each other', () => {
tree = parser.parse(`
const ab = abc + 1;
const def = de + 1;
const ghi = ghi + 1;
`);
query = JavaScript.query(`
((variable_declarator
name: (identifier) @id1
value: (binary_expression
left: (identifier) @id2))
(eq? @id1 @id2))
`);
const captures = query.captures(tree.rootNode);
assert.deepEqual(
formatCaptures(captures),
[
{name: "id1", text: "ghi"},
{name: "id2", text: "ghi"},
]
);
});
}); });
}); });

View file

@ -481,6 +481,10 @@ static TSQueryError ts_query_parse_predicate(
})); }));
} }
else {
return TSQueryErrorSyntax;
}
step_count++; step_count++;
stream_skip_whitespace(stream); stream_skip_whitespace(stream);
} }