From d4d554b2aefcd4df019bcf996d1aee2220d894e5 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 16 Sep 2019 10:25:44 -0700 Subject: [PATCH] Add wasm bindings for predicates --- cli/src/tests/query_test.rs | 9 ++ lib/binding_rust/lib.rs | 42 ++++---- lib/binding_web/binding.js | 149 +++++++++++++++++++++++++---- lib/binding_web/exports.json | 6 +- lib/binding_web/test/query-test.js | 80 +++++++++++++++- lib/src/query.c | 4 + 6 files changed, 249 insertions(+), 41 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 91f4cc35..586ba7f9 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -36,6 +36,15 @@ fn test_query_errors_on_invalid_syntax() { Query::new(language, r#"(identifier) "h "#), Err(QueryError::Syntax(13)) ); + + assert_eq!( + Query::new(language, r#"((identifier) ()"#), + Err(QueryError::Syntax(16)) + ); + assert_eq!( + Query::new(language, r#"((identifier) @x (eq? @x a"#), + Err(QueryError::Syntax(26)) + ); }); } diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 17bb0ffe..7a21c4ef 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -155,9 +155,7 @@ pub struct QueryCursor(*mut ffi::TSQueryCursor); pub struct QueryMatch<'a> { pattern_index: usize, - capture_count: usize, - captures_ptr: *const ffi::TSQueryCapture, - cursor: PhantomData<&'a ()>, + captures: &'a [ffi::TSQueryCapture], } #[derive(Debug, PartialEq, Eq)] @@ -1140,24 +1138,32 @@ impl QueryCursor { &'a mut self, query: &'a Query, node: Node<'a>, - text_callback: impl FnMut(Node<'a>) -> &'a [u8], + mut text_callback: impl FnMut(Node<'a>) -> &'a [u8] + 'a, ) -> impl Iterator> + 'a { unsafe { ffi::ts_query_cursor_exec(self.0, query.ptr, node.0); } std::iter::from_fn(move || -> Option> { - unsafe { - let mut m = MaybeUninit::::uninit(); - if ffi::ts_query_cursor_next_match(self.0, m.as_mut_ptr()) { - let m = m.assume_init(); - Some(QueryMatch { - pattern_index: m.pattern_index as usize, - capture_count: m.capture_count as usize, - captures_ptr: m.captures, - cursor: PhantomData, - }) - } else { - None + loop { + unsafe { + let mut m = MaybeUninit::::uninit(); + if ffi::ts_query_cursor_next_match(self.0, m.as_mut_ptr()) { + let m = m.assume_init(); + let captures = slice::from_raw_parts(m.captures, m.capture_count as usize); + if self.captures_match_condition( + query, + captures, + m.pattern_index as usize, + &mut text_callback, + ) { + return Some(QueryMatch { + pattern_index: m.pattern_index as usize, + captures, + }); + } + } else { + return None; + } } } }) @@ -1260,9 +1266,7 @@ impl<'a> QueryMatch<'a> { } pub fn captures(&self) -> impl ExactSizeIterator { - let captures = - unsafe { slice::from_raw_parts(self.captures_ptr, self.capture_count as usize) }; - captures + self.captures .iter() .map(|capture| (capture.index as usize, Node::new(capture.node).unwrap())) } diff --git a/lib/binding_web/binding.js b/lib/binding_web/binding.js index 599f4fb5..20c3fdbe 100644 --- a/lib/binding_web/binding.js +++ b/lib/binding_web/binding.js @@ -7,6 +7,10 @@ const SIZE_OF_RANGE = 2 * SIZE_OF_INT + 2 * SIZE_OF_POINT; const ZERO_POINT = {row: 0, column: 0}; const QUERY_WORD_REGEX = /[\w-.]*/g; +const PREDICATE_STEP_TYPE_DONE = 0; +const PREDICATE_STEP_TYPE_CAPTURE = 1; +const PREDICATE_STEP_TYPE_STRING = 2; + var VERSION; var MIN_COMPATIBLE_VERSION; var TRANSFER_BUFFER; @@ -661,21 +665,8 @@ class Language { TRANSFER_BUFFER, TRANSFER_BUFFER + SIZE_OF_INT ); - if (address) { - const captureCount = C._ts_query_capture_count(address); - const captureNames = new Array(captureCount); - for (let i = 0; i < captureCount; i++) { - const nameAddress = C._ts_query_capture_name_for_id( - address, - i, - TRANSFER_BUFFER - ); - const nameLength = getValue(TRANSFER_BUFFER, 'i32'); - captureNames[i] = UTF8ToString(nameAddress, nameLength); - } - C._free(sourceAddress); - return new Query(INTERNAL, address, captureNames); - } else { + + if (!address) { const errorId = getValue(TRANSFER_BUFFER + SIZE_OF_INT, 'i32'); const errorByte = getValue(TRANSFER_BUFFER, 'i32'); const errorIndex = UTF8ToString(sourceAddress, errorByte).length; @@ -689,6 +680,9 @@ class Language { case 3: error = new RangeError(`Bad field name '${word}'`); break; + case 4: + error = new RangeError(`Bad capture name @${word}`); + break; default: error = new SyntaxError(`Bad syntax at offset ${errorIndex}: '${suffix}'...`); break; @@ -698,6 +692,63 @@ class Language { C._free(sourceAddress); throw error; } + + const stringCount = C._ts_query_string_count(address); + const captureCount = C._ts_query_capture_count(address); + const patternCount = C._ts_query_pattern_count(address); + const captureNames = new Array(captureCount); + const stringValues = new Array(stringCount); + + for (let i = 0; i < captureCount; i++) { + const nameAddress = C._ts_query_capture_name_for_id( + address, + i, + TRANSFER_BUFFER + ); + const nameLength = getValue(TRANSFER_BUFFER, 'i32'); + captureNames[i] = UTF8ToString(nameAddress, nameLength); + } + + for (let i = 0; i < stringCount; i++) { + const valueAddress = C._ts_query_string_value_for_id( + address, + i, + TRANSFER_BUFFER + ); + const nameLength = getValue(TRANSFER_BUFFER, 'i32'); + stringValues[i] = UTF8ToString(valueAddress, nameLength); + } + + const predicates = new Array(patternCount); + for (let i = 0; i < patternCount; i++) { + const predicatesAddress = C._ts_query_predicates_for_pattern( + address, + i, + TRANSFER_BUFFER + ); + const stepCount = getValue(TRANSFER_BUFFER, 'i32'); + + predicates[i] = []; + const steps = []; + let stepAddress = predicatesAddress; + for (let j = 0; j < stepCount; j++) { + const stepType = getValue(stepAddress, 'i32'); + stepAddress += SIZE_OF_INT; + const stepValueId = getValue(stepAddress, 'i32'); + stepAddress += SIZE_OF_INT; + if (stepType === PREDICATE_STEP_TYPE_CAPTURE) { + steps.push({type: 'capture', name: captureNames[stepValueId]}); + } else if (stepType === PREDICATE_STEP_TYPE_STRING) { + steps.push({type: 'string', value: stringValues[stepValueId]}); + } else if (steps.length > 0) { + predicates[i].push(buildQueryPredicate(steps)); + steps.length = 0; + } + } + } + + C._free(sourceAddress); + return new Query(INTERNAL, address, captureNames, predicates); } static load(url) { @@ -733,10 +784,11 @@ class Language { } class Query { - constructor(internal, address, captureNames) { + constructor(internal, address, captureNames, predicates) { assertInternal(internal); this[0] = address; this.captureNames = captureNames; + this.predicates = predicates; } delete() { @@ -771,7 +823,9 @@ class Query { const captures = new Array(captureCount); address = unmarshalCaptures(this, node.tree, address, captures); - result[i] = {pattern, captures}; + if (this.predicates[pattern].every(p => p(captures))) { + result[i] = {pattern, captures}; + } } C._free(startAddress); @@ -809,7 +863,7 @@ class Query { const captures = new Array(captureCount); address = unmarshalCaptures(this, node.tree, address, captures); - if (capturesMatchConditions(this, node.tree, pattern, captures)) { + if (this.predicates[pattern].every(p => p(captures))) { result.push(captures[captureIndex]); } } @@ -819,8 +873,63 @@ class Query { } } -function capturesMatchConditions(query, tree, pattern, captures) { - return true; +function buildQueryPredicate(steps) { + if (steps[0].type !== 'string') { + throw new Error('Predicates must begin with a literal value'); + } + + switch (steps[0].value) { + case 'eq?': + if (steps.length !== 3) throw new Error( + `Wrong number of arguments to \`eq?\` predicate. Expected 2, got ${steps.length - 1}` + ); + if (steps[1].type !== 'capture') throw new Error( + `First argument of \`eq?\` predicate must be a capture. Got "${steps[1].value}"` + ); + if (steps[2].type === 'capture') { + const captureName1 = steps[1].name; + const captureName2 = steps[2].name; + return function(captures) { + let node1, node2 + for (const c of captures) { + if (c.name === captureName1) node1 = c.node; + if (c.name === captureName2) node2 = c.node; + } + return node1.text === node2.text + } + } else { + const captureName = steps[1].name; + const stringValue = steps[2].value; + return function(captures) { + for (const c of captures) { + if (c.name === captureName) return c.node.text === stringValue; + } + return false; + } + } + + case 'match?': + if (steps.length !== 3) throw new Error( + `Wrong number of arguments to \`match?\` predicate. Expected 2, got ${steps.length - 1}.` + ); + if (steps[1].type !== 'capture') throw new Error( + `First argument of \`match?\` predicate must be a capture. Got "${steps[1].value}".` + ); + if (steps[2].type !== 'string') throw new Error( + `Second argument of \`match?\` predicate must be a string. Got @${steps[2].value}.` + ); + const captureName = steps[1].name; + const regex = new RegExp(steps[2].value); + return function(captures) { + for (const c of captures) { + if (c.name === captureName) return regex.test(c.node.text); + } + return false; + } + + default: + throw new Error(`Unknown query predicate \`${steps[0].value}\``); + } } function unmarshalCaptures(query, tree, address, result) { diff --git a/lib/binding_web/exports.json b/lib/binding_web/exports.json index 6b0eab30..33fbad7a 100644 --- a/lib/binding_web/exports.json +++ b/lib/binding_web/exports.json @@ -70,12 +70,16 @@ "_ts_parser_set_language", "_ts_query_capture_count", "_ts_query_capture_name_for_id", + "_ts_query_captures_wasm", "_ts_query_context_delete", "_ts_query_context_new", "_ts_query_delete", "_ts_query_matches_wasm", - "_ts_query_captures_wasm", "_ts_query_new", + "_ts_query_pattern_count", + "_ts_query_predicates_for_pattern", + "_ts_query_string_count", + "_ts_query_string_value_for_id", "_ts_tree_cursor_current_field_id_wasm", "_ts_tree_cursor_current_node_id_wasm", "_ts_tree_cursor_current_node_is_missing_wasm", diff --git a/lib/binding_web/test/query-test.js b/lib/binding_web/test/query-test.js index 5d7ce620..06ef5370 100644 --- a/lib/binding_web/test/query-test.js +++ b/lib/binding_web/test/query-test.js @@ -19,7 +19,7 @@ describe("Query", () => { }); describe('construction', () => { - it('throws an error on invalid syntax', () => { + it('throws an error on invalid patterns', () => { assert.throws(() => { JavaScript.query("(function_declaration wat)") }, "Bad syntax at offset 22: \'wat)\'..."); @@ -33,6 +33,24 @@ describe("Query", () => { JavaScript.query("(function_declaration non_existent:(identifier))") }, "Bad field name 'non_existent'"); }); + + it('throws an error on invalid predicates', () => { + assert.throws(() => { + JavaScript.query("((identifier) @abc (eq? @ab hi))") + }, "Bad capture name @ab"); + assert.throws(() => { + JavaScript.query("((identifier) @abc (eq? @ab hi))") + }, "Bad capture name @ab"); + assert.throws(() => { + JavaScript.query("((identifier) @abc (eq?))") + }, "Wrong number of arguments to `eq?` predicate. Expected 2, got 0"); + assert.throws(() => { + JavaScript.query("((identifier) @a (eq? @a @a @a))") + }, "Wrong number of arguments to `eq?` predicate. Expected 2, got 3"); + assert.throws(() => { + JavaScript.query("((identifier) @a (something-else? @a))") + }, "Unknown query predicate `something-else?`"); + }); }); describe('.matches', () => { @@ -119,6 +137,66 @@ describe("Query", () => { ] ); }); + + it('handles conditions that compare the text of capture to literal strings', () => { + tree = parser.parse(` + const ab = require('./ab'); + new Cd(EF); + `); + + query = JavaScript.query(` + (identifier) @variable + + ((identifier) @function.builtin + (eq? @function.builtin "require")) + + ((identifier) @constructor + (match? @constructor "^[A-Z]")) + + ((identifier) @constant + (match? @constant "^[A-Z]{2,}$")) + `); + + const captures = query.captures(tree.rootNode); + assert.deepEqual( + formatCaptures(captures), + [ + {name: "variable", text: "ab"}, + {name: "variable", text: "require"}, + {name: "function.builtin", text: "require"}, + {name: "variable", text: "Cd"}, + {name: "constructor", text: "Cd"}, + {name: "variable", text: "EF"}, + {name: "constructor", text: "EF"}, + {name: "constant", text: "EF"}, + ] + ); + }); + + it('handles conditions that compare the text of capture to each other', () => { + tree = parser.parse(` + const ab = abc + 1; + const def = de + 1; + const ghi = ghi + 1; + `); + + query = JavaScript.query(` + ((variable_declarator + name: (identifier) @id1 + value: (binary_expression + left: (identifier) @id2)) + (eq? @id1 @id2)) + `); + + const captures = query.captures(tree.rootNode); + assert.deepEqual( + formatCaptures(captures), + [ + {name: "id1", text: "ghi"}, + {name: "id2", text: "ghi"}, + ] + ); + }); }); }); diff --git a/lib/src/query.c b/lib/src/query.c index 7a90b5eb..0a03a753 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -481,6 +481,10 @@ static TSQueryError ts_query_parse_predicate( })); } + else { + return TSQueryErrorSyntax; + } + step_count++; stream_skip_whitespace(stream); }