diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index b6d669c9..ff4dd7a1 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -56,8 +56,12 @@ fn test_query_errors_on_invalid_symbols() { Err(QueryError::NodeType("non_existent3")) ); assert_eq!( - Query::new(language, "(if_statement not_a_field: (identifier))"), - Err(QueryError::Field("not_a_field")) + Query::new(language, "(if_statement condit: (identifier))"), + Err(QueryError::Field("condit")) + ); + assert_eq!( + Query::new(language, "(if_statement conditioning: (identifier))"), + Err(QueryError::Field("conditioning")) ); }); } @@ -368,6 +372,67 @@ fn test_query_exec_within_byte_range() { }); } +#[test] +fn test_query_exec_different_queries() { + allocations::record(|| { + let language = get_language("javascript"); + let query1 = Query::new( + language, + " + (array (identifier) @id1) + ", + ) + .unwrap(); + let query2 = Query::new( + language, + " + (array (identifier) @id1) + (pair (identifier) @id2) + ", + ) + .unwrap(); + let query3 = Query::new( + language, + " + (array (identifier) @id1) + (pair (identifier) @id2) + (parenthesized_expression (identifier) @id3) + ", + ) + .unwrap(); + + let source = "[a, {b: b}, (c)];"; + + let mut parser = Parser::new(); + let mut cursor = QueryCursor::new(); + + parser.set_language(language).unwrap(); + let tree = parser.parse(&source, None).unwrap(); + + let matches = cursor.exec(&query1, tree.root_node()); + assert_eq!( + collect_matches(matches, &query1, source), + &[(0, vec![("id1", "a")]),] + ); + + let matches = cursor.exec(&query3, tree.root_node()); + assert_eq!( + collect_matches(matches, &query3, source), + &[ + (0, vec![("id1", "a")]), + (1, vec![("id2", "b")]), + (2, vec![("id3", "c")]), + ] + ); + + let matches = cursor.exec(&query2, tree.root_node()); + assert_eq!( + collect_matches(matches, &query2, source), + &[(0, vec![("id1", "a")]), (1, vec![("id2", "b")]),] + ); + }); +} + #[test] fn test_query_capture_names() { allocations::record(|| { diff --git a/lib/binding_web/binding.c b/lib/binding_web/binding.c index 501cf1cd..db4449a2 100644 --- a/lib/binding_web/binding.c +++ b/lib/binding_web/binding.c @@ -567,15 +567,25 @@ int ts_node_is_missing_wasm(const TSTree *tree) { /* Section - Query */ /******************/ -void ts_query_exec_wasm(const TSQuery *self, const TSTree *tree) { +void ts_query_exec_wasm( + const TSQuery *self, + const TSTree *tree, + uint32_t start_row, + uint32_t start_column, + uint32_t end_row, + uint32_t end_column +) { if (!scratch_query_cursor) scratch_query_cursor = ts_query_cursor_new(); TSNode node = unmarshal_node(tree); + TSPoint start_point = {start_row, code_unit_to_byte(start_column)}; + TSPoint end_point = {end_row, code_unit_to_byte(end_column)}; Array(const void *) result = array_new(); unsigned index = 0; unsigned match_count = 0; + ts_query_cursor_set_point_range(scratch_query_cursor, start_point, end_point); ts_query_cursor_exec(scratch_query_cursor, self, node); while (ts_query_cursor_next(scratch_query_cursor)) { match_count++; @@ -586,7 +596,7 @@ void ts_query_exec_wasm(const TSQuery *self, const TSTree *tree) { &capture_count ); - array_grow_by(&result, 1 + 6 * capture_count); + array_grow_by(&result, 2 + 6 * capture_count); result.contents[index++] = (const void *)pattern_index; result.contents[index++] = (const void *)capture_count; diff --git a/lib/binding_web/binding.js b/lib/binding_web/binding.js index 85bd4053..0fd1ea63 100644 --- a/lib/binding_web/binding.js +++ b/lib/binding_web/binding.js @@ -688,24 +688,30 @@ class Language { const nameLength = getValue(TRANSFER_BUFFER, 'i32'); captureNames[i] = UTF8ToString(nameAddress, nameLength); } + C._free(sourceAddress); return new Query(INTERNAL, address, captureNames); } else { const errorId = getValue(TRANSFER_BUFFER + SIZE_OF_INT, 'i32'); - const utf8ErrorOffset = getValue(TRANSFER_BUFFER, 'i32'); - const errorOffset = UTF8ToString(sourceAddress, utf8ErrorOffset).length; - C._free(sourceAddress); - const suffix = source.slice(errorOffset, 100); + const errorByte = getValue(TRANSFER_BUFFER, 'i32'); + const errorIndex = UTF8ToString(sourceAddress, errorByte).length; + const suffix = source.slice(errorIndex, 100); + const word = suffix.match(QUERY_WORD_REGEX)[0]; + let error; switch (errorId) { - case 2: throw new RangeError( - `Bad node name '${suffix.match(QUERY_WORD_REGEX)[0]}'` - ); - case 3: throw new RangeError( - `Bad field name '${suffix.match(QUERY_WORD_REGEX)[0]}'` - ); - default: throw new SyntaxError( - `Bad syntax at offset ${errorOffset}: '${suffix}'...` - ); + case 2: + error = new RangeError(`Bad node name '${word}'`); + break; + case 3: + error = new RangeError(`Bad field name '${word}'`); + break; + default: + error = new SyntaxError(`Bad syntax at offset ${errorIndex}: '${suffix}'...`); + break; } + error.index = errorIndex; + error.length = word.length; + C._free(sourceAddress); + throw error; } } @@ -752,10 +758,20 @@ class Query { C._ts_query_delete(this[0]); } - exec(queryNode) { + exec(queryNode, startPosition, endPosition) { + if (!startPosition) startPosition = ZERO_POINT; + if (!endPosition) endPosition = ZERO_POINT; + marshalNode(queryNode); - C._ts_query_exec_wasm(this[0], queryNode.tree[0]); + C._ts_query_exec_wasm( + this[0], + queryNode.tree[0], + startPosition.row, + startPosition.column, + endPosition.row, + endPosition.column + ); const matchCount = getValue(TRANSFER_BUFFER, 'i32'); const nodesAddress = getValue(TRANSFER_BUFFER + SIZE_OF_INT, 'i32'); diff --git a/lib/binding_web/test/query-test.js b/lib/binding_web/test/query-test.js index cfa45bd4..4fd44165 100644 --- a/lib/binding_web/test/query-test.js +++ b/lib/binding_web/test/query-test.js @@ -41,10 +41,7 @@ describe("Query", () => { `); const matches = query.exec(tree.rootNode); assert.deepEqual( - matches.map(({pattern, captures}) => ({ - pattern, - captures: captures.map(({name, node}) => ({name, text: node.text})) - })), + formatMatches(matches), [ {pattern: 0, captures: [{name: 'fn-def', text: 'one'}]}, {pattern: 1, captures: [{name: 'fn-ref', text: 'two'}]}, @@ -52,4 +49,33 @@ describe("Query", () => { ] ); }); + + it('matches queries in specified ranges', () => { + tree = parser.parse("[a, b,\nc, d,\ne, f,\ng, h]"); + query = JavaScript.query('(identifier) @element'); + const matches = query.exec( + tree.rootNode, + {row: 1, column: 1}, + {row: 3, column: 1} + ); + assert.deepEqual( + formatMatches(matches), + [ + {pattern: 0, captures: [{name: 'element', text: 'd'}]}, + {pattern: 0, captures: [{name: 'element', text: 'e'}]}, + {pattern: 0, captures: [{name: 'element', text: 'f'}]}, + {pattern: 0, captures: [{name: 'element', text: 'g'}]}, + ] + ); + }); }); + +function formatMatches(matches) { + return matches.map(({pattern, captures}) => ({ + pattern, + captures: captures.map(({name, node}) => ({ + name, + text: node.text + })) + })) +} diff --git a/lib/src/language.c b/lib/src/language.c index 1bfb1a8d..e96a3cbf 100644 --- a/lib/src/language.c +++ b/lib/src/language.c @@ -96,7 +96,8 @@ TSFieldId ts_language_field_id_for_name( for (TSSymbol i = 1; i < count + 1; i++) { switch (strncmp(name, self->field_names[i], name_length)) { case 0: - return i; + if (self->field_names[i][name_length] == 0) return i; + break; case -1: return 0; default: diff --git a/lib/src/query.c b/lib/src/query.c index 323abebc..76f0d672 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -783,6 +783,7 @@ bool ts_query_cursor_next(TSQueryCursor *self) { .step_index = slice->step_index, .pattern_index = slice->pattern_index, .capture_list_id = capture_list_id, + .capture_count = 0, })); }